Unverified Commit 29051439, authored by Lei Wang and committed by GitHub

[Lint] Phase out Yapf format and embrace ruff format (#1417)

parent e84b24bc
@@ -8,6 +8,7 @@ import argparse
def get_configs():
import itertools
BLOCK_N = [16, 32, 64, 128]
BLOCK_H = [16, 32, 64, 128]
num_split = [1, 2, 4, 8, 16, 32]
@@ -15,30 +16,26 @@ def get_configs():
_configs = list(itertools.product(BLOCK_N, BLOCK_H, num_split, threads))
return [{
"block_N": c[0],
"block_H": c[1],
"num_split": c[2],
"threads": c[3],
} for c in _configs]
return [
{
"block_N": c[0],
"block_H": c[1],
"num_split": c[2],
"threads": c[3],
}
for c in _configs
]
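# The autotune search space is the Cartesian product of the tile sizes
# (block_N, block_H), split-K factors, and thread counts listed above.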
@tilelang.autotune(configs=get_configs())
@tilelang.jit(
out_idx=[6], pass_configs={
out_idx=[6],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashmla_decode(batch,
heads,
kv_head_num,
seqlen_kv,
dim,
pe_dim,
block_N,
block_H,
num_split,
threads=128):
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
},
)
def flashmla_decode(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, threads=128):
scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e)
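# 1.44269504 is log2(e): folding it into the scale lets the online softmax use
# exp2(x) in place of exp(x) (exp(x) == 2**(x * log2(e))), which lowers to the
# GPU's hardware exp2 instruction when TL_ENABLE_FAST_MATH is set.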
dtype = "float16"
accum_dtype = "float"
kv_group_num = heads // kv_head_num
@@ -47,11 +44,11 @@ def flashmla_decode(batch,
@T.macro
def flash_attn(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=threads) as (bx, by):
Q_local = T.alloc_fragment([block_H, dim], dtype)
@@ -70,24 +67,19 @@ def flashmla_decode(batch,
cur_kv_head = by // (kv_group_num // block_H)
T.use_swizzle(10)
T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_local)
T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_local)
T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_local)
T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_local)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv(seqlen_kv, block_N)
for k in T.Pipelined(loop_range, num_stages=0):
T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared)
T.copy(KV[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], KV_shared)
T.copy(K_pe[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared)
T.clear(acc_s)
T.gemm(Q_local, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.gemm(
Q_pe_local,
K_pe_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
T.gemm(Q_pe_local, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
@@ -107,20 +99,18 @@ def flashmla_decode(batch,
T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :])
T.copy(acc_o, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :])
@T.macro
def flash_attn_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
):
with T.Kernel(
batch, heads // min(block_H, kv_group_num), num_split,
threads=threads) as (bx, by, bz):
with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=threads) as (bx, by, bz):
Q_local = T.alloc_fragment([block_H, dim], dtype)
Q_pe_local = T.alloc_fragment([block_H, pe_dim], dtype)
KV_shared = T.alloc_shared([block_N, dim], dtype)
@@ -136,8 +126,8 @@ def flashmla_decode(batch,
cur_kv_head = by // (kv_group_num // block_H)
T.use_swizzle(10)
T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_local)
T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_local)
T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_local)
T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_local)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
@@ -150,12 +140,7 @@ def flashmla_decode(batch,
T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
T.clear(acc_s)
T.gemm(Q_local, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.gemm(
Q_pe_local,
K_pe_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow)
T.gemm(Q_pe_local, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
@@ -176,14 +161,14 @@ def flashmla_decode(batch,
acc_o[i, j] /= logsum[i]
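# Next, convert this split's softmax denominator into a base-2 log-sum-exp
# (log2 of the sum plus the scaled running max); combine() reads it back
# from glse to reweight the partial outputs.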
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, glse[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz])
T.copy(acc_o, Output_partial[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz, :])
T.copy(logsum, glse[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz])
T.copy(acc_o, Output_partial[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz, :])
@T.macro
def combine(
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim], dtype)
@@ -193,9 +178,11 @@ def flashmla_decode(batch,
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout({
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
})
T.annotate_layout(
{
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
T.clear(lse_logsum_local)
T.clear(o_accum_local)
@@ -218,26 +205,26 @@ def flashmla_decode(batch,
@T.prim_func
def main_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
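# Split-K decode path: flash_attn_split attends over num_split disjoint KV
# chunks, emitting partial outputs plus per-chunk base-2 log-sum-exp values,
# and combine merges the partials using those LSEs as softmax weights.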
flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial)
combine(glse, Output_partial, Output)
@T.prim_func
def main_no_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
flash_attn(Q, Q_pe, KV, K_pe, Output)
@@ -262,43 +249,36 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
dim = q.shape[-1]
pe_dim = q_pe.shape[-1]
num_head_groups = q.shape[1] // kv.shape[2]
scale = (dim + pe_dim)**0.5
q = rearrange(
q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
scale = (dim + pe_dim) ** 0.5
q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
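# heads == h * g with g = num_head_groups, so each kv head serves a group of
# query heads (the GQA/MQA sharing pattern this reference follows).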
q_pe = rearrange(
q_pe, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim]
q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim]
kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim]
kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim]
k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, pe_dim]
k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, pe_dim]
query = torch.concat([q, q_pe], dim=-1)
key = torch.concat([kv, k_pe], dim=-1)
scores = einsum(
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv]
scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax(
scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
out = einsum(attention, kv,
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
return out
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=128, help='batch size')
parser.add_argument('--heads', type=int, default=128, help='q heads number')
parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
parser.add_argument('--dim', type=int, default=512, help='head dim')
parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
parser.add_argument('--autotune', action='store_true', help='auto tune')
parser.add_argument("--batch", type=int, default=128, help="batch size")
parser.add_argument("--heads", type=int, default=128, help="q heads number")
parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number")
parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length")
parser.add_argument("--dim", type=int, default=512, help="head dim")
parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim")
parser.add_argument("--autotune", action="store_true", help="auto tune")
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
enable_autotune = args.autotune
@@ -314,17 +294,7 @@ if __name__ == "__main__":
if enable_autotune:
kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
else:
kernel = flashmla_decode(
batch,
heads,
kv_heads,
kv_ctx,
dim,
pe_dim,
BLOCK_N,
BLOCK_H,
num_split,
threads=threads)
kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, threads=threads)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
input_tensors = profiler._get_inputs()
tilelang_output = kernel(*input_tensors)
@@ -32,8 +32,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@torch.inference_mode()
def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q,
h_kv, d, dv, causal, dtype):
def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
blocked_v = blocked_k[..., :dv]
def ref_mla():
@@ -94,8 +93,7 @@ def _mla_attn_kernel(
offs_d_ckv = tl.arange(0, HEAD_DIM_CKV)
cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)
offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[
None, :]
offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :]
q_nope = tl.load(Q_nope + offs_q_nope)
offs_d_kpe = tl.arange(0, HEAD_DIM_KPE)
@@ -141,9 +139,7 @@ def _mla_attn_kernel(
e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max
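# Online-softmax bookkeeping: re_scale folds the shift from the previous
# running max into e_sum so tiles seen earlier stay correctly normalized.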
offs_o = cur_batch * stride_o_b + cur_head[:,
None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[
None, :]
offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :]
tl.store(O + offs_o, acc / e_sum[:, None])
offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV
tl.store(O + offs_o_1, e_max + tl.log(e_sum))
@@ -309,24 +305,30 @@ def mla_decode_triton(
@torch.inference_mode()
def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
blocked_v = blocked_k[..., :dv]
assert d > dv, "MLA dim with rope (d) should be larger than the no-rope dim (dv)"
q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[...,
dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
def flash_mla_triton():
num_kv_splits = 32
o = torch.empty([b * s_q, h_q, dv])
attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1])
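# dv + 1: each split stores dv output values plus one trailing slot for its
# log-sum-exp, which the kernel writes at offset HEAD_DIM_CKV.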
mla_decode_triton(
q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope.view(-1, dv),
blocked_k_pe.view(-1, d - dv), o, block_table, cache_seqlens, attn_logits,
num_kv_splits, 1 / math.sqrt(d), block_size)
q_nope.view(-1, h_q, dv),
q_pe.view(-1, h_q, d - dv),
blocked_k_nope.view(-1, dv),
blocked_k_pe.view(-1, d - dv),
o,
block_table,
cache_seqlens,
attn_logits,
num_kv_splits,
1 / math.sqrt(d),
block_size,
)
return o.view([b, s_q, h_q, dv])
out_flash = flash_mla_triton()
@@ -362,14 +364,15 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
q = torch.randn(b, s_q, h_q, d)
block_size = 64
block_table = torch.arange(
b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
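# Identity paging: page j of batch i maps to physical block
# i * (max_seqlen_pad // block_size) + j, so each sequence's KV cache is
# contiguous in memory.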
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b,
s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b,
s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
out_a, lse_a, perf_a = baseline_func(
q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
out_b, lse_b, perf_b = target_func(
q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2, msg="out")
if target not in ["flash_mla_triton"]:
@@ -377,21 +380,14 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2, msg="lse")
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
torch.finfo(dtype).bits // 8)
print(
f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s"
)
print(
f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s"
)
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
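# Rough roofline accounting: FLOPS counts both GEMMs per head (QK^T over d
# dims, P @ V over dv dims) at 2 flops per multiply-add, and bytes models HBM
# traffic as one pass over the KV cache plus reading Q and writing O.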
print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s")
print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s")
return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b
def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
print(
f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}"
)
print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}")
torch.set_default_dtype(dtype)
device = torch.device("cuda:0")
torch.set_default_device(device)
@@ -408,19 +404,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
q = torch.randn(b, s_q, h_q, d)
block_size = 64
block_table = torch.arange(
b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b,
s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
out_b, lse_b, perf_b = target_func(
q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
torch.finfo(dtype).bits // 8)
print(
f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s"
)
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s")
return bytes / 10**6 / perf_b
@@ -429,26 +422,22 @@ available_targets = [
"flash_mla_triton",
]
shape_configs = [{
"b":
batch,
"s_q":
1,
"cache_seqlens":
torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"),
"h_q":
head,
"h_kv":
1,
"d":
512 + 64,
"dv":
512,
"causal":
True,
"dtype":
torch.float16
} for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 16384] for head in [128]]
shape_configs = [
{
"b": batch,
"s_q": 1,
"cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"),
"h_q": head,
"h_kv": 1,
"d": 512 + 64,
"dv": 512,
"causal": True,
"dtype": torch.float16,
}
for batch in [128]
for seqlen in [1024, 2048, 4096, 8192, 16384]
for head in [128]
]
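# cache_seqlens staggers the per-batch lengths (seqlen + 2*i) so the benchmark
# also exercises ragged, non-uniform KV cache lengths.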
def get_args():
......@@ -470,26 +459,54 @@ if __name__ == "__main__":
for shape in shape_configs:
if args.all:
for target in available_targets:
perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"],
shape["h_q"], shape["h_kv"], shape["d"], shape["dv"],
shape["causal"], shape["dtype"])
perf = compare_a(
target,
shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write(
f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n'
f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n"
)
elif args.compare:
perfa, perfb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"],
shape["cache_seqlens"], shape["h_q"], shape["h_kv"],
shape["d"], shape["dv"], shape["causal"], shape["dtype"])
perfa, perfb = compare_ab(
args.baseline,
args.target,
shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write(
f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n'
f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n"
)
fout.write(
f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfb:.0f}\n'
f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfb:.0f}\n"
)
elif args.one:
perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"],
shape["h_q"], shape["h_kv"], shape["d"], shape["dv"],
shape["causal"], shape["dtype"])
perf = compare_a(
args.target,
shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write(
f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n'
f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n"
)
@@ -29,8 +29,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@torch.inference_mode()
def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q,
h_kv, d, dv, causal, dtype):
def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
blocked_v = blocked_k[..., :dv]
def ref_mla():
@@ -91,8 +90,7 @@ def _mla_attn_kernel(
offs_d_ckv = tl.arange(0, HEAD_DIM_CKV)
cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)
offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[
None, :]
offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :]
q_nope = tl.load(Q_nope + offs_q_nope)
offs_d_kpe = tl.arange(0, HEAD_DIM_KPE)
@@ -138,9 +136,7 @@ def _mla_attn_kernel(
e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max
offs_o = cur_batch * stride_o_b + cur_head[:,
None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[
None, :]
offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :]
tl.store(O + offs_o, acc / e_sum[:, None])
offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV
tl.store(O + offs_o_1, e_max + tl.log(e_sum))
@@ -306,24 +302,30 @@ def mla_decode_triton(
@torch.inference_mode()
def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
blocked_v = blocked_k[..., :dv]
assert d > dv, "MLA dim with rope (d) should be larger than the no-rope dim (dv)"
q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[...,
dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
def flash_mla_triton():
num_kv_splits = 32
o = torch.empty([b * s_q, h_q, dv])
attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1])
mla_decode_triton(
q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope.view(-1, dv),
blocked_k_pe.view(-1, d - dv), o, block_table, cache_seqlens, attn_logits,
num_kv_splits, 1 / math.sqrt(d), block_size)
q_nope.view(-1, h_q, dv),
q_pe.view(-1, h_q, d - dv),
blocked_k_nope.view(-1, dv),
blocked_k_pe.view(-1, d - dv),
o,
block_table,
cache_seqlens,
attn_logits,
num_kv_splits,
1 / math.sqrt(d),
block_size,
)
return o.view([b, s_q, h_q, dv])
out_flash = flash_mla_triton()
@@ -359,14 +361,15 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
q = torch.randn(b, s_q, h_q, d)
block_size = 64
block_table = torch.arange(
b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b,
s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b,
s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
out_a, lse_a, perf_a = baseline_func(
q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
out_b, lse_b, perf_b = target_func(
q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2, msg="out")
if target not in ["flash_mla_triton"]:
@@ -374,21 +377,14 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2, msg="lse")
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
torch.finfo(dtype).bits // 8)
print(
f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s"
)
print(
f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s"
)
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s")
print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s")
return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b
def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
print(
f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}"
)
print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}")
torch.set_default_dtype(dtype)
device = torch.device("cuda:0")
torch.set_default_device(device)
@@ -405,19 +401,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
q = torch.randn(b, s_q, h_q, d)
block_size = 64
block_table = torch.arange(
b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b,
s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
out_b, lse_b, perf_b = target_func(
q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
torch.finfo(dtype).bits // 8)
print(
f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s"
)
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s")
return bytes / 10**6 / perf_b
@@ -426,26 +419,22 @@ available_targets = [
"flash_mla_triton",
]
shape_configs = [{
"b":
batch,
"s_q":
1,
"cache_seqlens":
torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"),
"h_q":
head,
"h_kv":
1,
"d":
512 + 64,
"dv":
512,
"causal":
True,
"dtype":
torch.float16
} for batch in [64, 128] for seqlen in [1024, 2048, 4096, 8192, 16384] for head in [128]]
shape_configs = [
{
"b": batch,
"s_q": 1,
"cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"),
"h_q": head,
"h_kv": 1,
"d": 512 + 64,
"dv": 512,
"causal": True,
"dtype": torch.float16,
}
for batch in [64, 128]
for seqlen in [1024, 2048, 4096, 8192, 16384]
for head in [128]
]
def get_args():
......@@ -467,26 +456,54 @@ if __name__ == "__main__":
for shape in shape_configs:
if args.all:
for target in available_targets:
perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"],
shape["h_q"], shape["h_kv"], shape["d"], shape["dv"],
shape["causal"], shape["dtype"])
perf = compare_a(
target,
shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write(
f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n'
f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n"
)
elif args.compare:
perfa, perfb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"],
shape["cache_seqlens"], shape["h_q"], shape["h_kv"],
shape["d"], shape["dv"], shape["causal"], shape["dtype"])
perfa, perfb = compare_ab(
args.baseline,
args.target,
shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write(
f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n'
f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n"
)
fout.write(
f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfb:.0f}\n'
f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfb:.0f}\n"
)
elif args.one:
perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"],
shape["h_q"], shape["h_kv"], shape["d"], shape["dv"],
shape["causal"], shape["dtype"])
perf = compare_a(
args.target,
shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write(
f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n'
f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n"
)
@@ -33,8 +33,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@torch.inference_mode()
def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q,
h_kv, d, dv, causal, dtype):
def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
blocked_v = blocked_k[..., :dv]
def ref_mla():
@@ -61,8 +60,7 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
@torch.inference_mode()
def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q,
h_kv, d, dv, causal, dtype):
def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
from flash_mla import flash_mla_with_kvcache, get_mla_metadata
blocked_v = blocked_k[..., :dv]
@@ -87,14 +85,13 @@ def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
@torch.inference_mode()
def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens,
h_q, h_kv, d, dv, causal, dtype):
def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
# pip install flashinfer-python
import flashinfer
assert d > dv, "MLA dim with rope (d) should be larger than the no-rope dim (dv)"
q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[...,
dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
kv_indptr = [0]
kv_indices = []
@@ -111,8 +108,7 @@ def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
torch.empty(128 * 1024 * 1024, dtype=torch.int8), backend="fa3")
mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(torch.empty(128 * 1024 * 1024, dtype=torch.int8), backend="fa3")
mla_wrapper.plan(
q_indptr,
kv_indptr,
@@ -129,12 +125,7 @@ def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q
)
def flashinfer():
output, lse = mla_wrapper.run(
q_nope.view(-1, h_q, dv),
q_pe.view(-1, h_q, d - dv),
blocked_k_nope,
blocked_k_pe,
return_lse=True)
output, lse = mla_wrapper.run(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope, blocked_k_pe, return_lse=True)
return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1)
out_flash, lse_flash = flashinfer()
@@ -177,8 +168,7 @@ def _mla_attn_kernel(
offs_d_ckv = tl.arange(0, HEAD_DIM_CKV)
cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H)
offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[
None, :]
offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :]
q_nope = tl.load(Q_nope + offs_q_nope)
offs_d_kpe = tl.arange(0, HEAD_DIM_KPE)
@@ -224,9 +214,7 @@ def _mla_attn_kernel(
e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max
offs_o = cur_batch * stride_o_b + cur_head[:,
None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[
None, :]
offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :]
tl.store(O + offs_o, acc / e_sum[:, None])
offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV
tl.store(O + offs_o_1, e_max + tl.log(e_sum))
@@ -393,24 +381,30 @@ def mla_decode_triton(
@torch.inference_mode()
def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
blocked_v = blocked_k[..., :dv]
assert d > dv, "MLA dim with rope (d) should be larger than the no-rope dim (dv)"
q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[...,
dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
def flash_mla_triton():
num_kv_splits = 32
o = torch.empty([b * s_q, h_q, dv])
attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1])
mla_decode_triton(
q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d - dv), blocked_k_nope.view(-1, dv),
blocked_k_pe.view(-1, d - dv), o, block_table, cache_seqlens, attn_logits,
num_kv_splits, 1 / math.sqrt(d), block_size)
q_nope.view(-1, h_q, dv),
q_pe.view(-1, h_q, d - dv),
blocked_k_nope.view(-1, dv),
blocked_k_pe.view(-1, d - dv),
o,
block_table,
cache_seqlens,
attn_logits,
num_kv_splits,
1 / math.sqrt(d),
block_size,
)
return o.view([b, s_q, h_q, dv])
out_flash = flash_mla_triton()
@@ -419,13 +413,10 @@ def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size,
@torch.inference_mode()
def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
assert d > dv, "MLA dim with rope (d) should be larger than the no-rope dim (dv)"
q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[...,
dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
dpe = d - dv
num_kv_splits = 1
@@ -434,8 +425,7 @@ def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size
out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)
kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H,
num_kv_splits, block_size)
kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size)
def flash_mla_tilelang():
out = kernel(
@@ -486,38 +476,31 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
q = torch.randn(b, s_q, h_q, d)
block_size = 64
block_table = torch.arange(
b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b,
s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b,
s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
out_a, lse_a, perf_a = baseline_func(
q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
out_b, lse_b, perf_b = target_func(
q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2, msg="out")
if target not in ["flashinfer", "flash_mla_triton", "tilelang"
] and baseline not in ["flashinfer", "flash_mla_triton", "tilelang"]:
if target not in ["flashinfer", "flash_mla_triton", "tilelang"] and baseline not in ["flashinfer", "flash_mla_triton", "tilelang"]:
# flashinfer has a different lse return value
# flash_mla_triton and flash_mla_tilelang don't return lse
torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2, msg="lse")
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
torch.finfo(dtype).bits // 8)
print(
f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s"
)
print(
f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s"
)
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s")
print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s")
return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b
def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
print(
f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}"
)
print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}")
torch.set_default_dtype(dtype)
device = torch.device("cuda:0")
torch.set_default_device(device)
@@ -534,19 +517,16 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
q = torch.randn(b, s_q, h_q, d)
block_size = 64
block_table = torch.arange(
b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b,
s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
out_b, lse_b, perf_b = target_func(
q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (
torch.finfo(dtype).bits // 8)
print(
f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s"
)
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s")
return bytes / 10**6 / perf_b
@@ -558,26 +538,22 @@ available_targets = [
"flash_mla_triton",
]
shape_configs = [{
"b":
batch,
"s_q":
1,
"cache_seqlens":
torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"),
"h_q":
head,
"h_kv":
1,
"d":
512 + 64,
"dv":
512,
"causal":
True,
"dtype":
torch.float16
} for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 16384, 32768] for head in [128]]
shape_configs = [
{
"b": batch,
"s_q": 1,
"cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"),
"h_q": head,
"h_kv": 1,
"d": 512 + 64,
"dv": 512,
"causal": True,
"dtype": torch.float16,
}
for batch in [128]
for seqlen in [1024, 2048, 4096, 8192, 16384, 32768]
for head in [128]
]
def get_args():
......@@ -599,26 +575,54 @@ if __name__ == "__main__":
for shape in shape_configs:
if args.all:
for target in available_targets:
perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"],
shape["h_q"], shape["h_kv"], shape["d"], shape["dv"],
shape["causal"], shape["dtype"])
perf = compare_a(
target,
shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write(
f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n'
f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n"
)
elif args.compare:
perfa, perfb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"],
shape["cache_seqlens"], shape["h_q"], shape["h_kv"],
shape["d"], shape["dv"], shape["causal"], shape["dtype"])
perfa, perfb = compare_ab(
args.baseline,
args.target,
shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write(
f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n'
f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfa:.0f}\n"
)
fout.write(
f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfb:.0f}\n'
f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perfb:.0f}\n"
)
elif args.one:
perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"],
shape["h_q"], shape["h_kv"], shape["d"], shape["dv"],
shape["causal"], shape["dtype"])
perf = compare_a(
args.target,
shape["b"],
shape["s_q"],
shape["cache_seqlens"],
shape["h_q"],
shape["h_kv"],
shape["d"],
shape["dv"],
shape["causal"],
shape["dtype"],
)
fout.write(
f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n'
f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n"
)
@@ -8,11 +8,12 @@ import argparse
@tilelang.jit(
out_idx=[6], pass_configs={
out_idx=[6],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split,
softmax_scale):
},
)
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale):
scale = float(softmax_scale * 1.44269504) # log2(e)
dtype = "float16"
accum_dtype = "float"
@@ -22,11 +23,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@T.macro
def flash_attn(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid):
Q_shared = T.alloc_shared([block_H, dim], dtype)
@@ -44,33 +45,24 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
logsum = T.alloc_fragment([block_H], accum_dtype)
cur_kv_head = hid // (kv_group_num // block_H)
T.annotate_layout({
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
})
T.annotate_layout(
{
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
}
)
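# Swizzled shared-memory layout for O_shared avoids bank conflicts when the
# accumulator is staged through shared memory before the global-memory write.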
T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv(seqlen_kv, block_N)
for k in T.Pipelined(loop_range, num_stages=2):
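# num_stages=2 software-pipelines this loop: copies of the next KV tile are
# issued while the GEMMs run on the current one.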
T.copy(KV[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
T.copy(K_pe[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared)
T.gemm(
Q_shared,
KV_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol,
clear_accum=True)
T.gemm(
Q_pe_shared,
K_pe_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
T.copy(KV[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], KV_shared)
T.copy(K_pe[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared)
T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
@@ -90,20 +82,18 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :])
T.copy(O_shared, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :])
@T.macro
def flash_attn_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
):
with T.Kernel(
batch, heads // min(block_H, kv_group_num), num_split,
threads=256) as (bid, hid, bz):
with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=256) as (bid, hid, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype)
S_shared = T.alloc_shared([block_H, block_N], dtype)
Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
@@ -121,13 +111,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
cur_kv_head = hid // (kv_group_num // block_H)
T.use_swizzle(10)
T.annotate_layout({
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
S_shared: tilelang.layout.make_swizzled_layout(S_shared),
})
T.annotate_layout(
{
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
S_shared: tilelang.layout.make_swizzled_layout(S_shared),
}
)
T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
@@ -139,14 +131,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared)
T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
T.clear(acc_s)
T.gemm(
Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(
Q_pe_shared,
K_pe_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
@@ -168,16 +154,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, bz])
T.copy(logsum, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz])
T.copy(acc_o, O_shared)
T.copy(O_shared, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H,
bz, :])
T.copy(O_shared, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, :])
@T.macro
def combine(
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(heads, batch, threads=128) as (hid, bz):
po_local = T.alloc_fragment([dim], dtype)
@@ -187,9 +172,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout({
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
})
T.annotate_layout(
{
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
T.clear(lse_logsum_local)
T.clear(o_accum_local)
@@ -212,26 +199,26 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@T.prim_func
def main_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial)
combine(glse, Output_partial, Output)
@T.prim_func
def main_no_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
flash_attn(Q, Q_pe, KV, K_pe, Output)
@@ -256,31 +243,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
dim = q.shape[-1]
pe_dim = q_pe.shape[-1]
num_head_groups = q.shape[1] // kv.shape[2]
scale = (dim + pe_dim)**0.5
q = rearrange(
q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
scale = (dim + pe_dim) ** 0.5
q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
q_pe = rearrange(
q_pe, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim]
q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim]
kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim]
kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim]
k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, pe_dim]
k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, pe_dim]
query = torch.concat([q, q_pe], dim=-1)
key = torch.concat([kv, k_pe], dim=-1)
scores = einsum(
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv]
scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax(
scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
out = einsum(attention, kv,
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
return out
@@ -298,10 +278,9 @@ def main(
BLOCK_N = 64
BLOCK_H = min(64, heads // kv_heads)
num_split = 1
softmax_scale = (dim + pe_dim)**-0.5
softmax_scale = (dim + pe_dim) ** -0.5
kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split,
softmax_scale)
kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, softmax_scale)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4)
latency = profiler.do_bench(warmup=500)
@@ -311,12 +290,12 @@ if __name__ == "__main__":
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=132, help='batch size')
parser.add_argument('--heads', type=int, default=128, help='q heads number')
parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
parser.add_argument('--dim', type=int, default=512, help='head dim')
parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
parser.add_argument("--batch", type=int, default=132, help="batch size")
parser.add_argument("--heads", type=int, default=128, help="q heads number")
parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number")
parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length")
parser.add_argument("--dim", type=int, default=512, help="head dim")
parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim")
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
main(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
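# Each kernel block above serves block_H query heads that share a single KV head;
# cur_kv_head = by // (kv_group_num // block_H) is the mapping used inside the
# kernels. A plain-Python sketch of that arithmetic with the default sizes
# (heads=128, kv_heads=1; for illustration only):
heads, kv_heads = 128, 1
kv_group_num = heads // kv_heads        # query heads per KV head
block_H = min(64, kv_group_num)
for by in range(heads // block_H):      # one kernel block per group of query heads
    cur_kv_head = by // (kv_group_num // block_H)
    print(f"block {by}: q heads {by * block_H}..{(by + 1) * block_H - 1} -> kv head {cur_kv_head}")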
......@@ -8,22 +8,14 @@ import math
@tilelang.jit(
out_idx=[8], pass_configs={
out_idx=[8],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
def mla_decode_tilelang(batch,
h_q,
h_kv,
max_seqlen_pad,
dv,
dpe,
block_N,
block_H,
num_split,
block_size,
softmax_scale=None):
},
)
def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split, block_size, softmax_scale=None):
if softmax_scale is None:
softmax_scale = (dv + dpe)**-0.5
softmax_scale = (dv + dpe) ** -0.5
scale = float(softmax_scale * 1.44269504) # log2(e)
dtype = "float16"
accum_dtype = "float"
......@@ -34,13 +26,13 @@ def mla_decode_tilelang(batch,
@T.macro
def flash_mla_kernel(
Q: T.Tensor([batch, h_q, dv], dtype),
Q_pe: T.Tensor([batch, h_q, dpe], dtype),
KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"),
CACHE_SEQLENS: T.Tensor([batch], "int32"),
Output: T.Tensor([batch, h_q, dv], dtype),
Q: T.Tensor([batch, h_q, dv], dtype),
Q_pe: T.Tensor([batch, h_q, dpe], dtype),
KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"),
CACHE_SEQLENS: T.Tensor([batch], "int32"),
Output: T.Tensor([batch, h_q, dv], dtype),
):
with T.Kernel(batch, h_q // min(block_H, kv_group_num), threads=256) as (bx, by):
Q_shared = T.alloc_shared([block_H, dv], dtype)
......@@ -59,13 +51,15 @@ def mla_decode_tilelang(batch,
cur_kv_head = by // (kv_group_num // block_H)
T.use_swizzle(10)
T.annotate_layout({
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
S_shared: tilelang.layout.make_swizzled_layout(S_shared),
})
T.annotate_layout(
{
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
S_shared: tilelang.layout.make_swizzled_layout(S_shared),
}
)
T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
......@@ -73,25 +67,17 @@ def mla_decode_tilelang(batch,
loop_range = T.ceildiv(CACHE_SEQLENS[bx], block_N)
for kr in T.Pipelined(loop_range, num_stages=2):
k = loop_range - 1 - kr
kv_start = BLOCK_TABLE[bx, (k * block_N) //
block_size] * block_size + (k * block_N) % block_size
T.copy(KV[kv_start:kv_start + block_N, cur_kv_head, :], KV_shared)
T.copy(K_pe[kv_start:kv_start + block_N, cur_kv_head, :], K_pe_shared)
kv_start = BLOCK_TABLE[bx, (k * block_N) // block_size] * block_size + (k * block_N) % block_size
T.copy(KV[kv_start : kv_start + block_N, cur_kv_head, :], KV_shared)
T.copy(K_pe[kv_start : kv_start + block_N, cur_kv_head, :], K_pe_shared)
T.clear(acc_s)
T.gemm(
Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(
Q_pe_shared,
K_pe_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
if kr == 0:
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx],
-T.infinity(accum_dtype), acc_s[i, j])
acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j])
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
......@@ -109,21 +95,20 @@ def mla_decode_tilelang(batch,
for i, j in T.Parallel(block_H, dv):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :])
T.copy(O_shared, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :])
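# The BLOCK_TABLE lookup above maps a logical KV position to a physical row of the
# flattened cache: page base plus within-page offset. The same arithmetic in plain
# Python, with a hypothetical block-table row:
def kv_row(block_table_row, pos, block_size):
    page = block_table_row[pos // block_size]   # logical page -> physical page
    return page * block_size + pos % block_size

table_row = [7, 3, 9]                           # hypothetical page mapping
assert kv_row(table_row, 0, 64) == 7 * 64       # first token lives in page 7
assert kv_row(table_row, 70, 64) == 3 * 64 + 6  # position 70 -> page 3, offset 6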
@T.macro
def flash_mla_split_kv_kernel(
Q: T.Tensor([batch, h_q, dv], dtype),
Q_pe: T.Tensor([batch, h_q, dpe], dtype),
KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"),
CACHE_SEQLENS: T.Tensor([batch], "int32"),
glse: T.Tensor([batch, h_q, num_split], dtype),
Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
Q: T.Tensor([batch, h_q, dv], dtype),
Q_pe: T.Tensor([batch, h_q, dpe], dtype),
KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
BLOCK_TABLE: T.Tensor([batch, max_seqlen_pad // block_size], "int32"),
CACHE_SEQLENS: T.Tensor([batch], "int32"),
glse: T.Tensor([batch, h_q, num_split], dtype),
Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
):
with T.Kernel(
batch, h_q // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz):
with T.Kernel(batch, h_q // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dv], dtype)
S_shared = T.alloc_shared([block_H, block_N], dtype)
Q_pe_shared = T.alloc_shared([block_H, dpe], dtype)
......@@ -141,13 +126,15 @@ def mla_decode_tilelang(batch,
cur_kv_head = by // (kv_group_num // block_H)
T.use_swizzle(10)
T.annotate_layout({
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
S_shared: tilelang.layout.make_swizzled_layout(S_shared),
})
T.annotate_layout(
{
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
S_shared: tilelang.layout.make_swizzled_layout(S_shared),
}
)
T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
......@@ -155,28 +142,20 @@ def mla_decode_tilelang(batch,
total_blocks = T.ceildiv(CACHE_SEQLENS[bx], block_N)
blocks_per_split = T.floordiv(total_blocks, num_split)
remaining_blocks = T.floormod(total_blocks, num_split)
loop_range = (blocks_per_split + T.if_then_else(bz < remaining_blocks, 1, 0))
loop_range = blocks_per_split + T.if_then_else(bz < remaining_blocks, 1, 0)
start = (blocks_per_split * bz + T.min(bz, remaining_blocks)) * block_N
for k in T.Pipelined(loop_range, num_stages=2):
kv_start = BLOCK_TABLE[bx, (start + k * block_N) //
block_size] * block_size + (k * block_N) % block_size
T.copy(KV[kv_start:kv_start + block_N, cur_kv_head, :], KV_shared)
T.copy(K_pe[kv_start:kv_start + block_N, cur_kv_head, :], K_pe_shared)
kv_start = BLOCK_TABLE[bx, (start + k * block_N) // block_size] * block_size + (start + k * block_N) % block_size  # page base + within-page offset of position start + k * block_N
T.copy(KV[kv_start : kv_start + block_N, cur_kv_head, :], KV_shared)
T.copy(K_pe[kv_start : kv_start + block_N, cur_kv_head, :], K_pe_shared)
T.clear(acc_s)
T.gemm(
Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(
Q_pe_shared,
K_pe_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(start + k * block_N + j >= CACHE_SEQLENS[bx],
-T.infinity(accum_dtype), acc_s[i, j])
acc_s[i, j] = T.if_then_else(start + k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j])
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
......@@ -196,15 +175,15 @@ def mla_decode_tilelang(batch,
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, glse[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz])
T.copy(logsum, glse[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz])
T.copy(acc_o, O_shared)
T.copy(O_shared, Output_partial[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz, :])
T.copy(O_shared, Output_partial[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, bz, :])
@T.macro
def combine(
glse: T.Tensor([batch, h_q, num_split], dtype),
Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
Output: T.Tensor([batch, h_q, dv], dtype),
glse: T.Tensor([batch, h_q, num_split], dtype),
Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
Output: T.Tensor([batch, h_q, dv], dtype),
):
with T.Kernel(h_q, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dv], dtype)
......@@ -214,9 +193,11 @@ def mla_decode_tilelang(batch,
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout({
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
})
T.annotate_layout(
{
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
T.clear(lse_logsum_local)
T.clear(o_accum_local)
......@@ -239,31 +220,30 @@ def mla_decode_tilelang(batch,
@T.prim_func
def main_split(
Q: T.Tensor([batch, h_q, dv], dtype),
Q_pe: T.Tensor([batch, h_q, dpe], dtype),
KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"),
cache_seqlens: T.Tensor([batch], "int32"),
glse: T.Tensor([batch, h_q, num_split], dtype),
Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
Output: T.Tensor([batch, h_q, dv], dtype),
Q: T.Tensor([batch, h_q, dv], dtype),
Q_pe: T.Tensor([batch, h_q, dpe], dtype),
KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"),
cache_seqlens: T.Tensor([batch], "int32"),
glse: T.Tensor([batch, h_q, num_split], dtype),
Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
Output: T.Tensor([batch, h_q, dv], dtype),
):
flash_mla_split_kv_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, glse,
Output_partial)
flash_mla_split_kv_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, glse, Output_partial)
combine(glse, Output_partial, Output)
@T.prim_func
def main_no_split(
Q: T.Tensor([batch, h_q, dv], dtype),
Q_pe: T.Tensor([batch, h_q, dpe], dtype),
KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"),
cache_seqlens: T.Tensor([batch], "int32"),
glse: T.Tensor([batch, h_q, num_split], dtype),
Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
Output: T.Tensor([batch, h_q, dv], dtype),
Q: T.Tensor([batch, h_q, dv], dtype),
Q_pe: T.Tensor([batch, h_q, dpe], dtype),
KV: T.Tensor([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Tensor([batch * max_seqlen_pad, h_kv, dpe], dtype),
block_table: T.Tensor([batch, max_seqlen_pad // block_size], "int32"),
cache_seqlens: T.Tensor([batch], "int32"),
glse: T.Tensor([batch, h_q, num_split], dtype),
Output_partial: T.Tensor([batch, h_q, num_split, dv], dtype),
Output: T.Tensor([batch, h_q, dv], dtype),
):
flash_mla_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, Output)
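# The combine macro merges the per-split partial outputs through their
# log2-sum-exp values: each split is weighted by exp2(lse - lse_max) and the sum
# is renormalized. A minimal PyTorch sketch of that reduction for one
# (batch, head) position (shapes illustrative):
import torch

num_split, dv = 4, 8
lse = torch.randn(num_split).double()           # log2-sum-exp per split, as in glse
partial = torch.randn(num_split, dv).double()   # per-split normalized outputs

lse_max = lse.max()
w = torch.exp2(lse - lse_max)                   # numerically safe split weights
out = (w[:, None] * partial).sum(0) / w.sum()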
......@@ -284,8 +264,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
s_q = query.shape[-2]
s_k = key.shape[-2]
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype, device=query.device)
temp_mask = torch.ones(
s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q)
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias = attn_bias.to(query.dtype)
attn_weight += attn_bias
......@@ -295,8 +274,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
@torch.inference_mode()
def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q,
h_kv, d, dv, causal, dtype):
def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
# q: [b, s_q, h_q, d]
# block_table: [b, max_seqlen_pad // block_size]
# blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d]
......@@ -325,13 +303,10 @@ def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
return out_torch
def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens,
h_q, h_kv, d, dv, causal, dtype):
def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
assert d > dv, "mla with rope dim should be larger than no rope dim"
q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[...,
dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
dpe = d - dv
num_kv_splits = 1
......@@ -341,8 +316,7 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)
kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H,
num_kv_splits, block_size, softmax_scale)
kernel = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size, softmax_scale)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
def flash_mla_tilelang():
......@@ -360,8 +334,7 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
out_flash = flash_mla_tilelang()
t = do_bench(flash_mla_tilelang)
out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q,
cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
torch.testing.assert_close(out_flash, out_ref, rtol=0.01, atol=0.01)
print("All close")
return out_flash, t
......@@ -369,12 +342,12 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=128, help='batch size')
parser.add_argument('--h_q', type=int, default=128, help='q heads number')
parser.add_argument('--h_kv', type=int, default=1, help='kv heads number')
parser.add_argument('--cache_seqlen', type=int, default=8192, help='kv cache context length')
parser.add_argument('--d', type=int, default=576, help='query/key head dim, d = dv + dpe')
parser.add_argument('--dv', type=int, default=512, help='value head dim')
parser.add_argument("--batch", type=int, default=128, help="batch size")
parser.add_argument("--h_q", type=int, default=128, help="q heads number")
parser.add_argument("--h_kv", type=int, default=1, help="kv heads number")
parser.add_argument("--cache_seqlen", type=int, default=8192, help="kv cache context length")
parser.add_argument("--d", type=int, default=576, help="query/key head dim, d = dv + dpe")
parser.add_argument("--dv", type=int, default=512, help="value head dim")
args = parser.parse_args()
b, h_q, h_kv, cache_seqlen, d, dv = args.batch, args.h_q, args.h_kv, args.cache_seqlen, args.d, args.dv
......@@ -383,9 +356,7 @@ if __name__ == "__main__":
s_q = 1 # for decode, s_q = 1
block_size = 64
cache_seqlens = torch.tensor([cache_seqlen + 2 * i for i in range(b)],
dtype=torch.int32,
device=device)
cache_seqlens = torch.tensor([cache_seqlen + 2 * i for i in range(b)], dtype=torch.int32, device=device)
dpe = d - dv
causal = True
......@@ -397,12 +368,11 @@ if __name__ == "__main__":
total_flops = s_q * total_seqlens * h_q * d * 2
q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device)
block_table = torch.arange(
b * max_seqlen_pad // block_size, dtype=torch.int32,
device=device).view(b, max_seqlen_pad // block_size)
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32, device=device).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d, dtype=dtype, device=device)
out_flash, latency = run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b,
s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
out_flash, latency = run_tilelang_mla(
q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
......@@ -9,11 +9,13 @@ import argparse
@tilelang.jit(
out_idx=[6], pass_configs={
out_idx=[6],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
},
)
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e)
dtype = "float16"
accum_dtype = "float"
kv_group_num = heads // kv_head_num
......@@ -23,13 +25,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@T.prim_func
def main_split_persistent(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(sm_num, threads=256) as (block_id):
Q_shared = T.alloc_shared([block_H, dim], dtype)
......@@ -53,11 +55,13 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout({
# O_shared: tilelang.layout.make_swizzled_layout(O_shared),
S_shared: tilelang.layout.make_swizzled_layout(S_shared),
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
})
T.annotate_layout(
{
# O_shared: tilelang.layout.make_swizzled_layout(O_shared),
S_shared: tilelang.layout.make_swizzled_layout(S_shared),
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
T.use_swizzle(10)
total_tiles = batch * (heads // min(block_H, kv_group_num)) * num_split
......@@ -70,8 +74,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
cur_kv_head = hid // (kv_group_num // block_H)
if bid < batch and hid * VALID_BLOCK_H < heads and sid < num_split:
T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
......@@ -83,26 +87,15 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared)
T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
T.clear(acc_s)
T.gemm(
Q_shared,
KV_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
T.gemm(
Q_pe_shared,
K_pe_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale -
scores_max[i] * scale)
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
......@@ -117,11 +110,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, sid])
T.copy(logsum, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, sid])
# T.copy(acc_o, O_shared)
T.copy(
acc_o, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H,
sid, :])
T.copy(acc_o, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, sid, :])
T.sync_grid()
waves = T.ceildiv(heads * batch, sm_num)
......@@ -167,42 +158,35 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
dim = q.shape[-1]
pe_dim = q_pe.shape[-1]
num_head_groups = q.shape[1] // kv.shape[2]
scale = (dim + pe_dim)**0.5
q = rearrange(
q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
scale = (dim + pe_dim) ** 0.5
q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
q_pe = rearrange(
q_pe, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim]
q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim]
kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim]
kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim]
k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim]
k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim]
query = torch.concat([q, q_pe], dim=-1)
key = torch.concat([kv, k_pe], dim=-1)
scores = einsum(
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv]
scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax(
scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
out = einsum(attention, kv,
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
return out
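# The persistent variant launches sm_num blocks and strides each one over a
# flattened (batch, head-block, split) tile index. One plausible decomposition of
# the linear id (the exact order sits in a hunk the diff truncates; sm_num and
# the sizes here are hypothetical):
sm_num = 132
batch, head_blocks, num_split = 128, 2, 4
total_tiles = batch * head_blocks * num_split
for block_id in range(sm_num):                          # one persistent block per SM
    for tile in range(block_id, total_tiles, sm_num):   # grid-stride over tiles
        sid = tile % num_split
        hid = (tile // num_split) % head_blocks
        bid = tile // (num_split * head_blocks)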
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=128, help='batch size')
parser.add_argument('--heads', type=int, default=128, help='q heads number')
parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
parser.add_argument('--dim', type=int, default=512, help='head dim')
parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
parser.add_argument("--batch", type=int, default=128, help="batch size")
parser.add_argument("--heads", type=int, default=128, help="q heads number")
parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number")
parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length")
parser.add_argument("--dim", type=int, default=512, help="head dim")
parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim")
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
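# The QK^T product touches dim + pe_dim channels per KV position while the PV
# product touches only dim, so the two terms are counted separately; presumably
# the truncated lines below sum them (assumed continuation, sizes from the defaults):
batch, heads, kv_ctx, dim, pe_dim = 128, 128, 8192, 512, 64
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
pv_flops = 2 * batch * heads * kv_ctx * dim
total_flops = qk_flops + pv_flops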
......
......@@ -13,14 +13,19 @@ import argparse
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
compile_flags=[
"-O3", "-Wno-deprecated-declarations", "-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "--expt-relaxed-constexpr", "--expt-extended-lambda",
"--ptxas-options=-v,--register-usage-level=10", "-DNDEBUG"
"-O3",
"-Wno-deprecated-declarations",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--ptxas-options=-v,--register-usage-level=10",
"-DNDEBUG",
],
)
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split,
softmax_scale):
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale):
sm_scale = float(softmax_scale * 1.44269504) # log2(e)
dtype = "float16"
accum_dtype = "float"
......@@ -30,11 +35,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@T.macro
def flash_attn(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=384) as (hid, bid):
Q_shared_l = T.alloc_shared([block_H, dim // 2], dtype)
......@@ -75,16 +80,16 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
tx = T.get_thread_binding()
T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, 0:dim // 2], Q_shared_l)
T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, dim // 2:dim], Q_shared_r)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_tail_shared)
T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, 0 : dim // 2], Q_shared_l)
T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, dim // 2 : dim], Q_shared_r)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_tail_shared)
T.barrier_arrive(bar_q)
if tx < 128:
T.set_max_nreg(240, 1)
T.fill(sumexp, 0)
T.fill(m_i, -2**30) # avoid -inf - inf to cause nan
T.fill(m_i, -(2**30))  # large finite negative instead of -inf; avoids -inf - (-inf) = NaN in the rescale
T.fill(acc_o_l, 0)
T.barrier_wait(bar_q, 0)
......@@ -166,8 +171,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for h_i in T.Parallel(block_H):
sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale
T.copy(acc_o_l, O_shared_l)
T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H,
0:dim // 2])
T.copy(O_shared_l, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, 0 : dim // 2])
elif tx >= 128 and tx < 256:
T.set_max_nreg(168, 1)
......@@ -197,8 +201,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
acc_o_r[h_i, d_i] /= sum_exp_shared[h_i]
T.copy(acc_o_r, O_shared_r)
T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H,
dim // 2:dim])
T.copy(O_shared_r, Output[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, dim // 2 : dim])
elif tx >= 256:
# producer
......@@ -211,19 +214,17 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
with T.attr("default", "async_scope", 1):
for u in T.serial(4):
for v in T.vectorized(8):
KV_shared_0_l[r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 +
v] = KV[bid, kv_indices, cur_kv_head,
64 * u + (tx - 256) % 8 * 8 + v]
KV_shared_0_r[r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 +
v] = KV[bid, kv_indices, cur_kv_head, dim // 2 +
64 * u + (tx - 256) % 8 * 8 + v]
KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[
bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v
]
KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[
bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v
]
with T.attr("default", "async_scope", 1):
for v in T.vectorized(8):
K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 +
v] = K_pe[bid, kv_indices, cur_kv_head,
(tx - 256) % 8 * 8 + v]
K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[
bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v
]
T.cp_async_barrier_noinc(bar_k_0_ready[0])
# Buffer 1
......@@ -233,33 +234,29 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
with T.attr("default", "async_scope", 1):
for u in T.serial(4):
for v in T.vectorized(8):
KV_shared_1_l[r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 +
v] = KV[bid, kv_indices, cur_kv_head,
64 * u + (tx - 256) % 8 * 8 + v]
KV_shared_1_r[r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 +
v] = KV[bid, kv_indices, cur_kv_head, dim // 2 +
64 * u + (tx - 256) % 8 * 8 + v]
KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[
bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v
]
KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[
bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v
]
with T.attr("default", "async_scope", 1):
for v in T.vectorized(8):
K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 +
v] = K_pe[bid, kv_indices, cur_kv_head,
(tx - 256) % 8 * 8 + v]
K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[
bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v
]
T.cp_async_barrier_noinc(bar_k_1_ready[0])
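# The producer branch above ping-pongs between two shared buffers, each guarded
# by a free/ready mbarrier pair whose expected phase flips with (i_i & 1) ^ 1.
# A CPU analogue of the same two-buffer handshake, with semaphores standing in
# for the mbarriers (a sketch, not the kernel's actual synchronization API):
import threading

free = [threading.Semaphore(1), threading.Semaphore(1)]    # like bar_k_*_free
ready = [threading.Semaphore(0), threading.Semaphore(0)]   # like bar_k_*_ready
buffers = [None, None]

def producer(n):
    for i in range(n):
        b = i % 2                    # alternate between buffer 0 and buffer 1
        free[b].acquire()            # wait until the consumer released it
        buffers[b] = f"tile {i}"     # stands in for the cp.async fills
        ready[b].release()           # signal the data is ready

def consumer(n):
    for i in range(n):
        b = i % 2
        ready[b].acquire()           # wait for the producer
        _ = buffers[b]               # stands in for the two GEMMs
        free[b].release()            # hand the buffer back

t0 = threading.Thread(target=producer, args=(8,))
t1 = threading.Thread(target=consumer, args=(8,))
t0.start(); t1.start(); t0.join(); t1.join()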
@T.macro
def flash_attn_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
):
with T.Kernel(
batch, heads // min(block_H, kv_group_num), num_split,
threads=384) as (bid, hid, bz):
with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=384) as (bid, hid, bz):
Q_shared_l = T.alloc_shared([block_H, dim // 2], dtype)
Q_shared_r = T.alloc_shared([block_H, dim // 2], dtype)
Q_tail_shared = T.alloc_shared([block_H, pe_dim], dtype)
......@@ -298,16 +295,16 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
tx = T.get_thread_binding()
T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, 0:dim // 2], Q_shared_l)
T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, dim // 2:dim], Q_shared_r)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_tail_shared)
T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, 0 : dim // 2], Q_shared_l)
T.copy(Q[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, dim // 2 : dim], Q_shared_r)
T.copy(Q_pe[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, :], Q_tail_shared)
T.barrier_arrive(bar_q)
if tx < 128:
T.set_max_nreg(240, 1)
T.fill(sumexp, 0)
T.fill(m_i, -2**30) # avoid -inf - inf to cause nan
T.fill(m_i, -(2**30))  # large finite negative instead of -inf; avoids -inf - (-inf) = NaN in the rescale
T.fill(acc_o_l, 0)
T.barrier_wait(bar_q, 0)
......@@ -389,10 +386,8 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for h_i in T.Parallel(block_H):
sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale
T.copy(acc_o_l, O_shared_l)
T.copy(
O_shared_l, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H,
bz, 0:dim // 2])
T.copy(sumexp, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, bz])
T.copy(O_shared_l, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, 0 : dim // 2])
T.copy(sumexp, glse[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz])
elif tx >= 128 and tx < 256:
T.set_max_nreg(168, 1)
......@@ -422,9 +417,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
acc_o_r[h_i, d_i] /= sum_exp_shared[h_i]
T.copy(acc_o_r, O_shared_r)
T.copy(
O_shared_r, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H,
bz, dim // 2:dim])
T.copy(O_shared_r, Output_partial[bid, hid * VALID_BLOCK_H : (hid + 1) * VALID_BLOCK_H, bz, dim // 2 : dim])
elif tx >= 256:
# producer
......@@ -433,54 +426,48 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
# Buffer 0
T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1))
for r in T.serial(4):
kv_indices = (seqlen_kv // num_split) * bz + (
i_i * 2) * block_N + r * 16 + (tx - 256) // 8
kv_indices = (seqlen_kv // num_split) * bz + (i_i * 2) * block_N + r * 16 + (tx - 256) // 8
with T.attr("default", "async_scope", 1):
for u in T.serial(4):
for v in T.vectorized(8):
KV_shared_0_l[r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 +
v] = KV[bid, kv_indices, cur_kv_head,
64 * u + (tx - 256) % 8 * 8 + v]
KV_shared_0_r[r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 +
v] = KV[bid, kv_indices, cur_kv_head, dim // 2 +
64 * u + (tx - 256) % 8 * 8 + v]
KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[
bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v
]
KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[
bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v
]
with T.attr("default", "async_scope", 1):
for v in T.vectorized(8):
K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 +
v] = K_pe[bid, kv_indices, cur_kv_head,
(tx - 256) % 8 * 8 + v]
K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[
bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v
]
T.cp_async_barrier_noinc(bar_k_0_ready[0])
# Buffer 1
T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1))
for r in T.serial(4):
kv_indices = (seqlen_kv // num_split) * bz + (
i_i * 2 + 1) * block_N + r * 16 + (tx - 256) // 8
kv_indices = (seqlen_kv // num_split) * bz + (i_i * 2 + 1) * block_N + r * 16 + (tx - 256) // 8
with T.attr("default", "async_scope", 1):
for u in T.serial(4):
for v in T.vectorized(8):
KV_shared_1_l[r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 +
v] = KV[bid, kv_indices, cur_kv_head,
64 * u + (tx - 256) % 8 * 8 + v]
KV_shared_1_r[r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 +
v] = KV[bid, kv_indices, cur_kv_head, dim // 2 +
64 * u + (tx - 256) % 8 * 8 + v]
KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[
bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v
]
KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[
bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v
]
with T.attr("default", "async_scope", 1):
for v in T.vectorized(8):
K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 +
v] = K_pe[bid, kv_indices, cur_kv_head,
(tx - 256) % 8 * 8 + v]
K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[
bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v
]
T.cp_async_barrier_noinc(bar_k_1_ready[0])
@T.macro
def combine(
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(heads, batch, threads=128) as (hid, bz):
po_local = T.alloc_fragment([dim], dtype)
......@@ -490,9 +477,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout({
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
})
T.annotate_layout(
{
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
}
)
T.clear(lse_logsum_local)
T.clear(o_accum_local)
......@@ -515,26 +504,26 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@T.prim_func
def main_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial)
combine(glse, Output_partial, Output)
@T.prim_func
def main_no_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
flash_attn(Q, Q_pe, KV, K_pe, Output)
......@@ -559,31 +548,24 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
dim = q.shape[-1]
pe_dim = q_pe.shape[-1]
num_head_groups = q.shape[1] // kv.shape[2]
scale = (dim + pe_dim)**0.5
q = rearrange(
q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
scale = (dim + pe_dim) ** 0.5
q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
q_pe = rearrange(
q_pe, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim]
q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim]
kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim]
kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim]
k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim]
k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim]
query = torch.concat([q, q_pe], dim=-1)
key = torch.concat([kv, k_pe], dim=-1)
scores = einsum(
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv]
scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax(
scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
out = einsum(attention, kv,
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
return out
......@@ -601,10 +583,9 @@ def main(
BLOCK_N = 64
BLOCK_H = min(64, heads // kv_heads)
num_split = 1
softmax_scale = (dim + pe_dim)**-0.5
softmax_scale = (dim + pe_dim) ** -0.5
kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split,
softmax_scale)
kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, softmax_scale)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4)
latency = profiler.do_bench(warmup=500)
......@@ -614,12 +595,12 @@ def main(
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=132, help='batch size')
parser.add_argument('--heads', type=int, default=128, help='q heads number')
parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
parser.add_argument('--dim', type=int, default=512, help='head dim')
parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
parser.add_argument("--batch", type=int, default=132, help="batch size")
parser.add_argument("--heads", type=int, default=128, help="q heads number")
parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number")
parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length")
parser.add_argument("--dim", type=int, default=512, help="head dim")
parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim")
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
main(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
......@@ -8,11 +8,13 @@ import argparse
@tilelang.jit(
out_idx=[-1], pass_configs={
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
})
},
)
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H):
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e)
dtype = "float16"
q_dtype = "float8_e4m3"
accum_dtype = "float"
......@@ -22,11 +24,11 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
@T.prim_func
def main_no_split(
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], q_dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
Q: T.Tensor([batch, heads, dim], dtype),
Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], q_dtype),
K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
Output: T.Tensor([batch, heads, dim], dtype),
):
with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=256) as (bx, by):
Q_shared = T.alloc_shared([block_H, dim], dtype)
......@@ -46,31 +48,27 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
cur_kv_head = by // (kv_group_num // block_H)
T.use_swizzle(10)
T.annotate_layout({
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
})
T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.annotate_layout(
{
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
}
)
T.copy(Q[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
T.disable_warp_group_reg_alloc()
loop_range = T.ceildiv(seqlen_kv, block_N)
for k in T.Pipelined(loop_range, num_stages=2):
T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], qKV_shared)
T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared)
T.copy(KV[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], qKV_shared)
T.copy(K_pe[bx, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_pe_shared)
T.copy(qKV_shared, KV_shared)
T.clear(acc_s)
T.gemm(
Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(
Q_pe_shared,
K_pe_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(Q_pe_shared, K_pe_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
......@@ -90,7 +88,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :])
T.copy(O_shared, Output[bx, by * VALID_BLOCK_H : (by + 1) * VALID_BLOCK_H, :])
return main_no_split
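# In this variant the KV cache is stored as float8_e4m3 and upcast to fp16 in
# shared memory (the qKV_shared -> KV_shared copy) before the GEMM. A quick
# round-trip sketch, assuming a PyTorch build with float8 support:
import torch

kv = torch.randn(64, 512, dtype=torch.float16)
kv_fp8 = kv.to(torch.float8_e4m3fn)        # storage dtype of the cache
kv_compute = kv_fp8.to(torch.float16)      # upcast before the GEMM, as the kernel does
print((kv - kv_compute).abs().max())       # quantization error of the round-trip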
......@@ -108,42 +106,35 @@ def ref_program(q, q_pe, kv, k_pe):
dim = q.shape[-1]
pe_dim = q_pe.shape[-1]
num_head_groups = q.shape[1] // kv.shape[2]
scale = (dim + pe_dim)**0.5
q = rearrange(
q, 'b (h g) d -> b g h d', g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
scale = (dim + pe_dim) ** 0.5
q = rearrange(q, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
q_pe = rearrange(
q_pe, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim]
q_pe = rearrange(q_pe, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, pe_dim]
kv = rearrange(kv, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim]
kv = rearrange(kv, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim]
k_pe = rearrange(k_pe, 'b n h d -> b h n d') # [batch_size, num_head_groups, groups, pe_dim]
k_pe = rearrange(k_pe, "b n h d -> b h n d") # [batch_size, num_head_groups, groups, pe_dim]
query = torch.concat([q, q_pe], dim=-1)
key = torch.concat([kv, k_pe], dim=-1)
scores = einsum(
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv]
scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax(
scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
out = einsum(attention, kv,
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
out = einsum(attention, kv, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim]
out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
return out
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=128, help='batch size')
parser.add_argument('--heads', type=int, default=128, help='q heads number')
parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
parser.add_argument('--dim', type=int, default=512, help='head dim')
parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
parser.add_argument("--batch", type=int, default=128, help="batch size")
parser.add_argument("--heads", type=int, default=128, help="q heads number")
parser.add_argument("--kv_heads", type=int, default=1, help="kv heads number")
parser.add_argument("--kv_ctx", type=int, default=8192, help="kv context length")
parser.add_argument("--dim", type=int, default=512, help="head dim")
parser.add_argument("--pe_dim", type=int, default=64, help="pe head dim")
args = parser.parse_args()
batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
......
......@@ -11,7 +11,7 @@ def flash_split_ref(Q, Q_pe, KV, K_pe):
block_N = 64
seqlen_kv = KV.size(1)
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504 # log2(e)
acc_s = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float)
acc_s_cast = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float16)
acc_o = torch.empty((batch, nheads, dim), device="cuda", dtype=torch.float)
......@@ -31,18 +31,20 @@ def flash_split_ref(Q, Q_pe, KV, K_pe):
for ks in range(num_split):
acc_o.fill_(0)
logsum.fill_(0)
scores_max.fill_(float('-inf'))
scores_max_prev.fill_(float('-inf'))
scores_max.fill_(float("-inf"))
scores_max_prev.fill_(float("-inf"))
for i in range(int((seqlen_kv // num_split) / block_N)):
acc_s.fill_(0)
acc_s = torch.einsum('bhd,bkhd->bhk', Q_,
KV_[:, (seqlen_kv // num_split) * ks +
i * block_N:(seqlen_kv // num_split) * ks +
(i + 1) * block_N, :, :]) # [batch, nheads, block_N]
acc_s = torch.einsum(
"bhd,bkhd->bhk",
Q_,
KV_[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
) # [batch, nheads, block_N]
acc_s += torch.einsum(
'bhd,bkhd->bhk', Q_pe_,
K_pe_[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks +
(i + 1) * block_N, :, :])
"bhd,bkhd->bhk",
Q_pe_,
K_pe_[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
)
scores_max_prev = scores_max
scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads]
scores_scale = torch.exp2(scores_max_prev - scores_max) # [batch, nheads]
......@@ -50,9 +52,10 @@ def flash_split_ref(Q, Q_pe, KV, K_pe):
acc_s = torch.exp2(acc_s - scores_max[:, :, None])
acc_s_cast = acc_s.to(torch.float16) # [batch, nheads, block_N]
acc_o += torch.einsum(
'bhk,bkhd->bhd', acc_s_cast,
KV_[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks +
(i + 1) * block_N, :, :])
"bhk,bkhd->bhd",
acc_s_cast,
KV_[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
)
scores_sum = acc_s.sum(dim=-1, keepdim=False)
logsum = logsum * scores_scale + scores_sum
acc_o /= logsum[:, :, None]
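# flash_split_ref keeps its running state in exp2-space, but the recurrence is the
# standard online softmax: track a running max, rescale the old accumulator when
# the max grows, normalize once at the end. A minimal natural-exp version for a
# single row, self-checked against the dense softmax:
import torch

def online_softmax_av(score_chunks, value_chunks):
    m = torch.tensor(float("-inf"))
    s = torch.tensor(0.0)
    o = torch.zeros(value_chunks[0].shape[-1])
    for sc, v in zip(score_chunks, value_chunks):
        m_new = torch.maximum(m, sc.max())
        alpha = torch.exp(m - m_new)        # rescale factor for the old state
        p = torch.exp(sc - m_new)
        s = s * alpha + p.sum()
        o = o * alpha + p @ v
        m = m_new
    return o / s

scores, values = torch.randn(128), torch.randn(128, 16)
out = online_softmax_av(scores.chunk(4), values.chunk(4))
assert torch.allclose(out, torch.softmax(scores, 0) @ values, atol=1e-5)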
......
......@@ -14,21 +14,44 @@ from fla.ops.utils import prepare_token_indices
from fla.utils import autocast_custom_fwd, contiguous
@triton.heuristics({
'USE_OFFSETS': lambda args: args['offsets'] is not None,
'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor),
})
@triton.heuristics(
{
"USE_OFFSETS": lambda args: args["offsets"] is not None,
"USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor),
}
)
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]],
key=['BS', 'BK', 'BV'],
key=["BS", "BK", "BV"],
)
@triton.jit
def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices,
block_counts, offsets, token_indices, T, H: tl.constexpr,
HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr,
S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr,
BV: tl.constexpr, USE_OFFSETS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr):
def parallel_nsa_fwd_kernel(
q,
k,
v,
o_slc,
o_swa,
lse_slc,
lse_swa,
scale,
block_indices,
block_counts,
offsets,
token_indices,
T,
H: tl.constexpr,
HQ: tl.constexpr,
G: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
S: tl.constexpr,
BS: tl.constexpr,
WS: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_OFFSETS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr,
):
i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
......@@ -40,20 +63,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
NS = S
p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK),
(1, 0))
p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
# the Q block is kept in the shared memory throughout the whole kernel
# [G, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV),
(G, BV), (1, 0))
p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_o_slc = tl.zeros([G, BV], dtype=tl.float32)
b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32)
b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32)
b_acc_slc = tl.zeros([G], dtype=tl.float32)
for i in range(NS):
i_s = tl.load(block_indices + i).to(tl.int32) * BS
......@@ -66,7 +87,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1))
# [G, BS]
b_s_slc = tl.dot(b_q, b_k_slc)
b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf'))
b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf"))
# [G]
b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc
......@@ -87,7 +108,6 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
......@@ -100,8 +120,7 @@ class ParallelNSAFunction(torch.autograd.Function):
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
token_indices = prepare_token_indices(offsets) if offsets is not None else None
o, lse = parallel_nsa_fwd(
q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale)
o, lse = parallel_nsa_fwd(q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale)
ctx.save_for_backward(q, k, v, o, lse)
ctx.block_indices = block_indices
ctx.block_size = block_size
......@@ -172,7 +191,6 @@ def parallel_nsa_fwd(
@torch.compile
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
......@@ -195,7 +213,8 @@ class ParallelNSAFunction(torch.autograd.Function):
window_size=window_size,
scale=scale,
offsets=offsets,
token_indices=token_indices)
token_indices=token_indices,
)
ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa)
ctx.block_indices = block_indices
ctx.block_counts = block_counts
......@@ -207,18 +226,20 @@ class ParallelNSAFunction(torch.autograd.Function):
return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa
def parallel_nsa(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False) -> torch.Tensor:
def parallel_nsa(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
......@@ -258,44 +279,44 @@ def parallel_nsa(q: torch.Tensor,
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if scale is None:
scale = k.shape[-1] ** -0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, "b h t -> b t h")
assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
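# With grouped-query attention, G = q.shape[2] // k.shape[2] query heads share
# each KV head; the kernels put a whole group in the row dimension of their GEMM
# tiles, which is why the group size must be a multiple of 16 here.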
if isinstance(block_counts, int):
block_indices = block_indices[:, :, :, :block_counts]
block_counts = None
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens)
if window_size > 0:
o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1))
else:
o = o_slc * g_slc.unsqueeze(-1)
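# Gate combination note: torch.addcmul(a, b, c) computes a + b * c, so with a
# window the result is o_slc * g_slc + o_swa * g_swa (per-head gates mixing the
# selected-block and sliding-window paths); without a window only the
# selected-block path is gated.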
if head_first:
o = rearrange(o, "b t h d -> b h t d")
return o
def naive_nsa(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
......@@ -335,26 +356,24 @@ def naive_nsa(q: torch.Tensor,
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if scale is None:
scale = k.shape[-1] ** -0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, "b h t -> b t h")
dtype = q.dtype
G = q.shape[2] // k.shape[2]
BS = block_size
S = block_indices.shape[-1]
k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices))
if isinstance(block_counts, torch.Tensor):
block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G)
c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device)
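# c[n, h] is the ordinal of the selected block containing flattened key position
# n (0 repeated BS times, then 1, ...); it is compared against block_counts
# below to mask out positions beyond each head's active-block budget.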
q, k, v = map(lambda x: x.float(), (q, k, v))
......@@ -364,14 +383,11 @@ def naive_nsa(q: torch.Tensor,
if cu_seqlens is None:
varlen = False
B, T = q.shape[:2]
cu_seqlens = torch.cat([block_indices.new_tensor(range(0, B * T, T)), block_indices.new_tensor([B * T])])
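# For the fixed-length case this synthesizes boundaries [0, T, 2T, ..., B * T];
# e.g. B=2, T=4 gives cu_seqlens = [0, 4, 8], so sequence i spans
# cu_seqlens[i]:cu_seqlens[i + 1] along the flattened token axis.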
for i in range(len(cu_seqlens) - 1):
if not varlen:
q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[i], block_indices[i]
if isinstance(block_counts, torch.Tensor):
s_b = block_counts[i]
else:
......@@ -379,10 +395,10 @@ def naive_nsa(q: torch.Tensor,
else:
T = cu_seqlens[i + 1] - cu_seqlens[i]
q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = map(
lambda x: x[0][cu_seqlens[i] : cu_seqlens[i + 1]], (q, k, v, g_slc, g_swa, block_indices)
)
if isinstance(block_counts, torch.Tensor):
s_b = block_counts[0][cu_seqlens[i] : cu_seqlens[i + 1]]
else:
s_b = block_counts
......@@ -404,71 +420,58 @@ def naive_nsa(q: torch.Tensor,
else:
s_i = s_b
# [S*BS, HQ, -1]
k_i_slc, v_i_slc = map(lambda x: x.gather(0, i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b))
# [S*BS, HQ]
attn_slc = (
torch.einsum("h d, n h d -> n h", q_i, k_i_slc)
.masked_fill(torch.logical_or(i_i < 0, i_i > i_q) | (c >= s_i if block_counts is not None else False), float("-inf"))
.softmax(0)
)
if not varlen:
o_slc[i, i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1)
else:
o_slc[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1)
if window_size > 0:
k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1) : i_q + 1], (k_b, v_b))
attn_swa = torch.einsum("h d, n h d -> n h", q_i, k_i_swa).softmax(0)
if not varlen:
o_swa[i, i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1)
else:
o_swa[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1)
if head_first:
o_slc = rearrange(o_slc, "b t h d -> b h t d")
o_swa = rearrange(o_swa, "b t h d -> b h t d")
return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype)
def get_configs():
import itertools
iter_params = dict(
block_T=[128, 256, 512],
num_stages=[0, 1, 2, 4, 5],
threads=[32, 64, 128, 256, 512],
)
return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
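# Each candidate is one point of the Cartesian product, e.g.
# {"block_T": 128, "num_stages": 0, "threads": 32}; 3 * 5 * 5 = 75 configs for
# the autotuner below to sweep.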
@tilelang.autotune(
configs=get_configs(),
)
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}
)
def tilelang_sparse_attention(
batch, heads, seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16, block_T=128, num_stages=2, threads=32
):
if scale is None:
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
else:
scale = scale * 1.44269504 # log2(e)
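# 1.44269504 is log2(e): folding it into `scale` lets the kernel evaluate the
# softmax with the cheaper exp2/log2 hardware instructions while matching the
# exp-based definition, since exp(x) == exp2(x * log2(e)). Quick check (illustrative):
#   import math; assert abs(math.exp(0.7) - 2 ** (0.7 * math.log2(math.e))) < 1e-12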
......@@ -493,11 +496,11 @@ def tilelang_sparse_attention(batch,
@T.prim_func
def tilelang_sparse_attention(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([G, BK], dtype)
......@@ -520,7 +523,7 @@ def tilelang_sparse_attention(batch,
i_b, i_h = i_bh // head_kv, i_bh % head_kv
NS = S
T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
......@@ -530,21 +533,15 @@ def tilelang_sparse_attention(batch,
i_s = BlockIndices[i_b, i_t, i_h, i] * BS
if i_s <= i_t and i_s >= 0:
# [BS, BK]
T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared)
if is_causal:
for i, j in T.Parallel(G, BS):
acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
# Softmax
T.copy(scores_max, scores_max_prev)
......@@ -564,45 +561,33 @@ def tilelang_sparse_attention(batch,
acc_o[i, j] *= scores_scale[i]
# V * softmax(Q * K)
T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(G, BV):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV])
return tilelang_sparse_attention
def generate_block_indices(batch, seq_len, heads, selected_blocks, block_size):
"""Generate random block indices for the benchmark."""
block_indices = torch.full((batch, seq_len, heads, selected_blocks), seq_len, dtype=torch.long, device="cuda")
for b in range(batch):
for t in range(seq_len):
for h in range(heads):
i_i = torch.randperm(max(1, (t // block_size)))[:selected_blocks]
block_indices[b, t, h, : len(i_i)] = i_i
return block_indices.sort(-1)[0]
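# Illustrative example: with seq_len=8, block_size=2, selected_blocks=2, a query
# at t=5 has t // block_size = 2 candidate key blocks, so its row is a sorted
# sample from {0, 1}; slots keep the seq_len fill value whenever fewer than
# `selected_blocks` candidates exist, marking them as invalid.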
def benchmark_nsa(
batch_size, seq_len, heads, head_query, dim, selected_blocks, block_size, dtype, scale, warmup=10, iterations=100, validate=False
):
"""Benchmark the TileLang Sparse Attention implementation."""
# Set random seed for reproducibility
......@@ -628,14 +613,13 @@ def benchmark_nsa(batch_size,
print(f"Profiler latency: {profiler_latency} ms")
# Create input tensors
Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda")
K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda")
V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda")
out = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda")
# Generate block indices
block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks, block_size).to(torch.int32)
# Warmup
for _ in range(warmup):
......@@ -666,10 +650,9 @@ def benchmark_nsa(batch_size,
# Validate result against reference if requested
if validate:
g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda")
g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda")
block_counts = torch.randint(1, selected_blocks + 1, (batch_size, seq_len, heads), device="cuda")
ref = naive_nsa(
q=Q,
......@@ -700,22 +683,13 @@ def benchmark_nsa(batch_size,
"head_query": head_query,
"dim": dim,
"selected_blocks": selected_blocks,
"block_size": block_size
"block_size": block_size,
}
def benchmark_triton_nsa(
batch_size, seq_len, heads, head_query, dim, selected_blocks, block_size, dtype, scale, warmup=10, iterations=100, validate=False
):
"""Benchmark the Triton-based TileLang Sparse Attention implementation."""
# Set random seed for reproducibility
......@@ -723,18 +697,17 @@ def benchmark_triton_nsa(batch_size,
torch.random.manual_seed(0)
# Create input tensors
Q = torch.randn((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda")
K = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda")
V = torch.randn((batch_size, seq_len, heads, dim), dtype=dtype, device="cuda")
g_slc = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda")
g_swa = torch.ones((batch_size, seq_len, head_query), dtype=dtype, device="cuda")
# Generate block indices
block_indices = generate_block_indices(batch_size, seq_len, heads, selected_blocks, block_size)
block_counts = torch.randint(1, selected_blocks + 1, (batch_size, seq_len, heads), device="cuda")
o_slc = torch.empty((batch_size, seq_len, head_query, dim), dtype=dtype, device="cuda")
lse_slc = torch.empty((batch_size, seq_len, head_query), dtype=torch.float, device="cuda")
# Warmup
for _ in range(warmup):
......@@ -750,7 +723,8 @@ def benchmark_triton_nsa(batch_size,
block_counts=block_counts,
block_size=block_size,
window_size=0,
scale=scale,
)
# Synchronize before timing
torch.cuda.synchronize()
......@@ -770,7 +744,8 @@ def benchmark_triton_nsa(batch_size,
block_counts=block_counts,
block_size=block_size,
window_size=0,
scale=scale,
)
torch.cuda.synchronize()
end_time = time.time()
......@@ -815,54 +790,28 @@ def benchmark_triton_nsa(batch_size,
"head_query": head_query,
"dim": dim,
"selected_blocks": selected_blocks,
"block_size": block_size
"block_size": block_size,
}
def run_benchmark_suite(impl="all"):
"""Run a suite of benchmarks with different configurations."""
# Define configurations to benchmark
configs = [
# Small model config - Note: head_query must be a multiple of heads*16 for Triton
{"batch_size": 2, "seq_len": 1024, "heads": 8, "head_query": 8 * 16, "dim": 64, "selected_blocks": 8, "block_size": 32},
# Medium model config
{"batch_size": 2, "seq_len": 2048, "heads": 16, "head_query": 16 * 16, "dim": 64, "selected_blocks": 16, "block_size": 64},
# Large model config
{"batch_size": 1, "seq_len": 4096, "heads": 32, "head_query": 32 * 16, "dim": 128, "selected_blocks": 32, "block_size": 128},
]
results = []
for config in configs:
print(f"Running benchmark with config: {config}")
if impl in ["all", "tilelang"]:
print("Benchmarking TileLang implementation:")
result = benchmark_nsa(
batch_size=config["batch_size"],
......@@ -874,12 +823,13 @@ def run_benchmark_suite(impl='all'):
block_size=config["block_size"],
dtype=torch.float16,
scale=0.1,
validate=False,
)
results.append({"impl": "tilelang", **result})
print(f"Average time: {result['avg_time_ms']:.2f} ms")
print(f"Performance: {result['tflops']:.2f} TFLOPs")
if impl in ["all", "triton"]:
print("Benchmarking Triton implementation:")
result = benchmark_triton_nsa(
batch_size=config["batch_size"],
......@@ -891,19 +841,24 @@ def run_benchmark_suite(impl='all'):
block_size=config["block_size"],
dtype=torch.float16,
scale=0.1,
validate=False,
)
results.append({"impl": "triton", **result})
print(f"Average time: {result['avg_time_ms']:.2f} ms")
print(f"Performance: {result['tflops']:.2f} TFLOPs")
if impl in ["all"]:
# Print comparison if both implementations were run
tilelang_result = next(
r
for r in results
if r["impl"] == "tilelang" and r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"]
)
triton_result = next(
r
for r in results
if r["impl"] == "triton" and r["batch_size"] == config["batch_size"] and r["seq_len"] == config["seq_len"]
)
speedup = tilelang_result["avg_time_ms"] / triton_result["avg_time_ms"]
print(f"Speedup (Triton vs TileLang): {speedup:.2f}x")
......@@ -921,8 +876,7 @@ if __name__ == "__main__":
parser.add_argument("--dim", type=int, default=128, help="Head dimension")
parser.add_argument("--selected_blocks", type=int, default=16, help="Number of selected blocks")
parser.add_argument("--block_size", type=int, default=32, help="Block size")
parser.add_argument("--dtype", type=str, default="float16", help="Data type (float16 or float32)")
parser.add_argument("--scale", type=float, default=0.1, help="Attention scale factor")
parser.add_argument("--iterations", type=int, default=100, help="Number of iterations")
parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations")
......@@ -933,7 +887,8 @@ if __name__ == "__main__":
type=str,
default="all",
choices=["tilelang", "triton", "all"],
help="Implementation to benchmark (tilelang, triton, or all)")
help="Implementation to benchmark (tilelang, triton, or all)",
)
args = parser.parse_args()
......@@ -941,8 +896,7 @@ if __name__ == "__main__":
if args.impl in ["triton", "all"] and args.head_query % (args.heads * 16) != 0:
# Adjust head_query to nearest valid value
args.head_query = ((args.head_query // (args.heads * 16)) + 1) * (args.heads * 16)
print(f"Adjusted head_query to {args.head_query} to be compatible with Triton implementation")
if args.suite:
run_benchmark_suite(impl=args.impl)
......@@ -963,12 +917,14 @@ if __name__ == "__main__":
scale=args.scale,
warmup=args.warmup,
iterations=args.iterations,
validate=args.validate,
)
print("\nBenchmark Results (TileLang):")
print(
f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, " +
f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, " +
f"block_size={args.block_size}")
f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, "
+ f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, "
+ f"block_size={args.block_size}"
)
print(f"Average time: {result['avg_time_ms']:.2f} ms")
print(f"Performance: {result['tflops']:.2f} TFLOPs")
......@@ -986,11 +942,13 @@ if __name__ == "__main__":
scale=args.scale,
warmup=args.warmup,
iterations=args.iterations,
validate=args.validate,
)
print("\nBenchmark Results (Triton):")
print(
f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, " +
f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, " +
f"block_size={args.block_size}")
f"Configuration: batch={args.batch}, seq_len={args.seq_len}, heads={args.heads}, "
+ f"head_query={args.head_query}, dim={args.dim}, blocks={args.selected_blocks}, "
+ f"block_size={args.block_size}"
)
print(f"Average time: {result['avg_time_ms']:.2f} ms")
print(f"Performance: {result['tflops']:.2f} TFLOPs")
......@@ -7,6 +7,7 @@ import torch
import triton
import fla
if parse(fla.__version__) < parse("0.2.1"):
from fla.ops.common.utils import prepare_token_indices
else:
......@@ -22,7 +23,8 @@ import tilelang
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}
)
def tilelang_kernel_fwd(
batch,
heads,
......@@ -34,11 +36,10 @@ def tilelang_kernel_fwd(
groups=1,
selected_blocks=16,
):
from tilelang import language as T
if scale is None:
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
else:
scale = scale * 1.44269504 # log2(e)
......@@ -67,12 +68,12 @@ def tilelang_kernel_fwd(
@T.prim_func
def native_sparse_attention(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype),
O_slc: T.Tensor(o_slc_shape, dtype),
LSE_slc: T.Tensor(lse_slc_shape, accum_dtype),
):
with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([G, BK], dtype)
......@@ -93,7 +94,7 @@ def tilelang_kernel_fwd(
i_b, i_h = i_bh // head_kv, i_bh % head_kv
NS = S
T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
......@@ -103,12 +104,11 @@ def tilelang_kernel_fwd(
i_s = BlockIndices[i_b, i_t, i_h, i] * BS
if i_s <= i_t and i_s >= 0:
# [BS, BK]
T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared)
if is_causal:
for k, j in T.Parallel(G, BS):
acc_s[k, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
......@@ -138,7 +138,7 @@ def tilelang_kernel_fwd(
acc_o[k, j] *= scores_scale[k]
# V * softmax(Q * K)
T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(G, BV):
......@@ -146,18 +146,20 @@ def tilelang_kernel_fwd(
T.copy(acc_o, O_shared)
T.copy(
O_shared,
O_slc[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV],
)
for i in T.Parallel(G):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, LSE_slc[i_b, i_t, i_h * G : (i_h + 1) * G])
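# logsum now holds the base-2 log-sum-exp per query row,
#   lse = log2(sum_j exp2(s_j * scale - m * scale)) + m * scale = log2(sum_j exp2(s_j * scale)),
# which the backward kernels reuse to recompute softmax weights directly as
# p_j = exp2(s_j * scale - lse) without redoing the max/sum reduction.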
return native_sparse_attention
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def tilelang_kernel_bwd_dkv(
batch,
heads,
......@@ -172,7 +174,7 @@ def tilelang_kernel_bwd_dkv(
accum_dtype="float",
):
if scale is None:
sm_scale = (1.0 / dim) ** 0.5
else:
sm_scale = scale
......@@ -207,15 +209,15 @@ def tilelang_kernel_bwd_dkv(
@T.prim_func
def flash_bwd_dkv(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(k_shape, dtype),
V: T.Tensor(v_shape, dtype),
LSE_slc: T.Tensor(lse_slc_shape, accum_dtype),
Delta_slc: T.Tensor(delta_slc_shape, accum_dtype),
DO_slc: T.Tensor(do_slc_shape, dtype),
DK: T.Tensor(dk_shape, dtype),
DV: T.Tensor(dv_shape, dtype),
BlockMask: T.Tensor(block_mask_shape, "int32"),
):
with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh):
K_shared = T.alloc_shared([BS, BK], dtype)
......@@ -238,31 +240,33 @@ def tilelang_kernel_bwd_dkv(
i_b, i_h = i_bh // H, i_bh % H
T.copy(K[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK], K_shared)
T.copy(V[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV], V_shared)
# [BS, BK]
T.clear(dk)
# [BS, BV]
T.clear(dv)
T.annotate_layout(
{
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
}
)
loop_st = i_s * BS
loop_ed = seq_len
for i in T.Pipelined(
start=loop_st,
stop=loop_ed,
num_stages=0,
):
b_m_slc = BlockMask[i_b, i, i_h, i_s]
if b_m_slc != 0:
# [G, BK]
T.copy(Q[i_b, i, i_h * G : (i_h + 1) * G, :BK], Q_shared)
T.clear(qkT)
# [BS, BK] @ [G, BK] -> [BS, G]
T.gemm(
......@@ -273,7 +277,7 @@ def tilelang_kernel_bwd_dkv(
policy=T.GemmWarpPolicy.FullRow,
)
# [G]
T.copy(LSE_slc[i_b, i, i_h * G : (i_h + 1) * G], lse_shared)
for _i, _j in T.Parallel(BS, G):
qkT[_i, _j] = T.exp2(qkT[_i, _j] * scale - lse_shared[_j])
......@@ -282,7 +286,7 @@ def tilelang_kernel_bwd_dkv(
qkT[_i, _j] = T.if_then_else(i >= (i_s * BS + _i), qkT[_i, _j], 0)
# [G, BV]
T.copy(DO_slc[i_b, i, i_h * G : (i_h + 1) * G, :BV], do)
T.clear(dsT)
# [BS, BV] @ [G, BV] -> [BS, G]
T.gemm(
......@@ -296,7 +300,7 @@ def tilelang_kernel_bwd_dkv(
# [BS, G] @ [G, BV] -> [BS, BV]
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
# [G]
T.copy(Delta_slc[i_b, i, i_h * G : (i_h + 1) * G], delta)
for i, j in T.Parallel(BS, G):
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
......@@ -305,8 +309,8 @@ def tilelang_kernel_bwd_dkv(
T.copy(dv, dv_shared)
T.copy(dk, dk_shared)
T.copy(dv_shared, DV[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV])
T.copy(dk_shared, DK[i_v, i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK])
return flash_bwd_dkv
......@@ -321,9 +325,11 @@ def make_dq_layout(dQ):
)
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
}
)
def tilelang_kernel_bwd_dqkv(
batch,
heads,
......@@ -338,7 +344,7 @@ def tilelang_kernel_bwd_dqkv(
accum_dtype="float",
):
if scale is None:
sm_scale = (1.0 / dim) ** 0.5
else:
sm_scale = scale
......@@ -373,16 +379,16 @@ def tilelang_kernel_bwd_dqkv(
@T.prim_func
def flash_bwd_dqkv(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(k_shape, dtype),
V: T.Tensor(v_shape, dtype),
LSE_slc: T.Tensor(lse_slc_shape, accum_dtype),
Delta_slc: T.Tensor(delta_slc_shape, accum_dtype),
DO_slc: T.Tensor(do_slc_shape, dtype),
DQ: T.Tensor(dq_shape, dtype),
DK: T.Tensor(dk_shape, dtype),
DV: T.Tensor(dv_shape, dtype),
BlockMask: T.Tensor(block_mask_shape, "int32"),
):
with T.Kernel(NV, NS, B * H, threads=num_threads) as (i_v, i_s, i_bh):
K_shared = T.alloc_shared([BS, BK], dtype)
......@@ -406,31 +412,33 @@ def tilelang_kernel_bwd_dqkv(
i_b, i_h = i_bh // H, i_bh % H
T.copy(K[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK], K_shared)
T.copy(V[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV], V_shared)
# [BS, BK]
T.clear(dk)
# [BS, BV]
T.clear(dv)
T.annotate_layout(
{
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
}
)
loop_st = i_s * BS
loop_ed = seq_len
for i in T.Pipelined(
start=loop_st,
stop=loop_ed,
num_stages=0,
):
b_m_slc = BlockMask[i_b, i, i_h, i_s]
if b_m_slc != 0:
# [G, BK]
T.copy(Q[i_b, i, i_h * G : (i_h + 1) * G, :BK], Q_shared)
T.clear(qkT)
# [BS, BK] @ [G, BK] -> [BS, G]
T.gemm(
......@@ -441,7 +449,7 @@ def tilelang_kernel_bwd_dqkv(
policy=T.GemmWarpPolicy.FullRow,
)
# [G]
T.copy(LSE_slc[i_b, i, i_h * G : (i_h + 1) * G], lse_shared)
for _i, _j in T.Parallel(BS, G):
qkT[_i, _j] = T.exp2(qkT[_i, _j] * scale - lse_shared[_j])
......@@ -450,7 +458,7 @@ def tilelang_kernel_bwd_dqkv(
qkT[_i, _j] = T.if_then_else(i >= (i_s * BS + _i), qkT[_i, _j], 0)
# [G, BV]
T.copy(DO_slc[i_b, i, i_h * G : (i_h + 1) * G, :BV], do)
T.clear(dsT)
# [BS, BV] @ [G, BV] -> [BS, G]
T.gemm(
......@@ -464,7 +472,7 @@ def tilelang_kernel_bwd_dqkv(
# [BS, G] @ [G, BV] -> [BS, BV]
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
# [G]
T.copy(Delta_slc[i_b, i, i_h * G : (i_h + 1) * G], delta)
for _i, _j in T.Parallel(BS, G):
dsT_cast[_i, _j] = qkT[_i, _j] * (dsT[_i, _j] - delta[_j]) * sm_scale
......@@ -480,16 +488,18 @@ def tilelang_kernel_bwd_dqkv(
T.copy(dv, dv_shared)
T.copy(dk, dk_shared)
T.copy(dv_shared, DV[i_b, i_s * BS : (i_s + 1) * BS, i_h, :BV])
T.copy(dk_shared, DK[i_v, i_b, i_s * BS : (i_s + 1) * BS, i_h, :BK])
return flash_bwd_dqkv
@tilelang.jit(
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def tilelang_kernel_preprocess(
batch,
heads,
......@@ -505,9 +515,9 @@ def tilelang_kernel_preprocess(
@T.prim_func
def flash_bwd_prep(
O: T.Tensor(shape, dtype), # type: ignore
dO: T.Tensor(shape, dtype), # type: ignore
Delta: T.Tensor([batch, seq_len, heads], accum_dtype), # type: ignore
):
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
o = T.alloc_fragment([blk, blk], dtype)
......@@ -516,20 +526,22 @@ def tilelang_kernel_preprocess(
delta = T.alloc_fragment([blk], accum_dtype)
T.clear(acc)
for k in range(T.ceildiv(dim, blk)):
T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o)
T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do)
for i, j in T.Parallel(blk, blk):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, by * blk : (by + 1) * blk, bx])
return flash_bwd_prep
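# flash_bwd_prep computes Delta[b, t, h] = sum_d O[b, t, h, d] * dO[b, t, h, d],
# the per-row (output . output-grad) dot product that FlashAttention-style
# backward passes use in dS = P * (dP - Delta); see the dsT_cast updates in the
# dkv/dqkv kernels above.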
@tilelang.jit(
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def tilelang_kernel_block_mask(
batch,
heads,
......@@ -551,9 +563,9 @@ def tilelang_kernel_block_mask(
@T.prim_func
def flash_bwd_block_mask(
BlockIndices: T.Tensor(block_indices_shape, dtype), # type: ignore
BlockCounts: T.Tensor(block_counts_shape, dtype), # type: ignore
BlockMask: T.Tensor(block_mask_shape, dtype), # type: ignore
):
with T.Kernel(seq_len, batch, heads * S) as (bx, by, bz):
i_t, i_b, i_hs = bx, by, bz
......@@ -603,9 +615,7 @@ def parallel_nsa_bwd(
dk = torch.empty(NV, *k.shape, dtype=k.dtype, device=q.device)
dv = torch.empty(v.shape, dtype=v.dtype, device=q.device)
block_mask = tilelang_kernel_block_mask(B, H, T, S, BS)(block_indices.to(torch.int32), block_counts.to(torch.int32)).to(torch.bool)
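# block_mask[b, t, h, s] is nonzero when key block s is among the first
# block_counts[b, t, h] selected blocks for query t; the backward kernel reads it
# (as b_m_slc) to skip unselected (query, block) pairs entirely.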
fused_qkv_bwd_kernel = tilelang_kernel_bwd_dqkv(
batch=B,
......@@ -618,8 +628,7 @@ def parallel_nsa_bwd(
selected_blocks=S,
scale=scale,
)
fused_qkv_bwd_kernel(q, k, v, lse_slc, delta_slc, do_slc, dq, dk, dv, block_mask.to(torch.int32))
dq = dq.sum(0)
dk = dk.sum(0)
......@@ -628,7 +637,6 @@ def parallel_nsa_bwd(
@torch.compile
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
......@@ -773,23 +781,21 @@ def parallel_nsa(
Outputs of shape `[B, SEQLEN, HQ, V]` if `head_first=False` else `[B, HQ, SEQLEN, V]`.
"""
if scale is None:
scale = k.shape[-1] ** -0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, "b h t -> b t h")
assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
if isinstance(block_counts, int):
block_indices = block_indices[:, :, :, :block_counts]
block_counts = None
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens)
if window_size > 0:
o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1))
else:
......@@ -814,7 +820,7 @@ if __name__ == "__main__":
for t in range(T):
for h in range(H):
i_i = torch.randperm(max(1, (t // block_size)))[:S]
block_indices[b, t, h, : len(i_i)] = i_i
block_indices = block_indices.sort(-1)[0]
block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda")
......
......@@ -16,7 +16,8 @@ tilelang.testing.set_random_seed(42)
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},
)
def native_sparse_attention(
batch,
heads,
......@@ -25,10 +26,10 @@ def native_sparse_attention(
scale=None,
block_size=64, # Tile size for attention computation
groups=1, # Grouped query attention (GQA) groups
selected_blocks=16, # Number of blocks to select per attention head
):
if scale is None:
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
# Modified shapes for inference (q has seq_len=1)
q_shape = [batch, 1, heads, dim] # Changed seq_len to 1
......@@ -53,12 +54,11 @@ def native_sparse_attention(
@T.prim_func
def native_sparse_attention(
Q: T.Tensor(q_shape, dtype), # [batch, 1, heads, dim]
K: T.Tensor(kv_shape, dtype), # [batch, seq_len, head_kv, dim]
V: T.Tensor(kv_shape, dtype), # Same shape as K
BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype), # Selected block indices
Output: T.Tensor(q_shape, dtype), # Output attention tensor
):
with T.Kernel(1, NV, batch * head_kv, threads=threads) as (bx, by, bz):
# Shared memory allocations for tile storage
......@@ -82,7 +82,7 @@ def native_sparse_attention(
NS = S
# Copy Q for the single position
T.copy(Q[i_b, 0, i_h * G : (i_h + 1) * G, :], Q_shared) # Changed i_t to 0
T.fill(acc_o, 0)
T.fill(logsum, 0)
......@@ -93,16 +93,11 @@ def native_sparse_attention(
i_s = BlockIndices[i_b, 0, i_h, i] * BS # Get block offset
if i_s >= 0: # Skip invalid/padding blocks
# Load current key block to shared memory
T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared)
# Compute QK^T attention scores
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
# Online softmax with numerical stability
# 1. Compute max for scaling
......@@ -122,15 +117,14 @@ def native_sparse_attention(
T.copy(acc_s, acc_s_cast)
# Accumulate attention-weighted values
T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
# Final normalization and output
for i, j in T.Parallel(G, BV):
acc_o[i, j] /= logsum[i] # Normalize by logsum
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[i_b, 0, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV]) # Changed i_t to 0
return native_sparse_attention
......@@ -149,21 +143,21 @@ def main():
selected_blocks=S,
)
Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device="cuda").requires_grad_(True)
K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True)
V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True)
mask = torch.randint(0, 2, (B, SEQ_LEN, groups), device="cuda")
DO = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device="cuda")
block_indices = torch.full((B, SEQ_LEN_Q, H, S), SEQ_LEN, dtype=torch.long, device="cuda")
for b in range(B):
for t in range(SEQ_LEN_Q):
for h in range(H):
i_i = torch.randperm(max(1, (t // block_size)))[:S]
block_indices[b, t, h, : len(i_i)] = i_i
block_indices = block_indices.sort(-1)[0]
block_counts = torch.randint(1, S + 1, (B, SEQ_LEN_Q, H), device="cuda")
out = kernel(Q, K, V, block_indices.to(torch.int32))
......
......@@ -14,18 +14,11 @@ tilelang.testing.set_random_seed(0)
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
},
)
def native_sparse_attention(batch, heads, seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16):
if scale is None:
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
else:
scale = scale * 1.44269504 # log2(e)
......@@ -52,11 +45,11 @@ def native_sparse_attention(batch,
@T.prim_func
def native_sparse_attention(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype),
Output: T.Tensor(q_shape, dtype),
):
with T.Kernel(seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([G, BK], dtype)
......@@ -77,7 +70,7 @@ def native_sparse_attention(batch,
i_b, i_h = i_bh // head_kv, i_bh % head_kv
NS = S
T.copy(Q[i_b, i_t, i_h * G : (i_h + 1) * G, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
......@@ -87,21 +80,15 @@ def native_sparse_attention(batch,
i_s = BlockIndices[i_b, i_t, i_h, i] * BS
if i_s <= i_t and i_s >= 0:
# [BS, BK]
T.copy(K[i_b, i_s : i_s + BS, i_h, :], K_shared)
if is_causal:
for i, j in T.Parallel(G, BS):
acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
# Softmax
T.copy(scores_max, scores_max_prev)
......@@ -121,13 +108,13 @@ def native_sparse_attention(batch,
acc_o[i, j] *= scores_scale[i]
# V * softmax(Q * K)
T.copy(V[i_b, i_s : i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(G, BV):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[i_b, i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV])
return native_sparse_attention
......@@ -148,20 +135,20 @@ def main():
)
print(kernel.get_kernel_source())
torch.random.manual_seed(0)
Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda").requires_grad_(True)
K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True)
V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device="cuda").requires_grad_(True)
g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True)
g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True)
DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device="cuda")
block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device="cuda")
block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device="cuda")
for b in range(B):
for t in range(SEQ_LEN):
for h in range(H):
i_i = torch.randperm(max(1, (t // block_size)))[:S]
block_indices[b, t, h, : len(i_i)] = i_i
block_counts[b, t, h] = (block_indices[b, t, h] != SEQ_LEN).sum().item()
block_indices = block_indices.sort(-1)[0]
......
......@@ -8,6 +8,7 @@ from tilelang import language as T
import tilelang.testing
import fla
if parse(fla.__version__) < parse("0.2.1"):
from fla.ops.common.utils import prepare_token_indices
else:
......@@ -21,18 +22,11 @@ from einops import rearrange
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}
)
def native_sparse_attention_varlen(batch, heads, c_seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16):
if scale is None:
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [c_seq_len, heads, dim]
kv_shape = [c_seq_len, head_kv, dim]
......@@ -66,14 +60,14 @@ def native_sparse_attention_varlen(batch,
@T.prim_func
def native_sparse_attention_varlen(
Q: T.Tensor(q_shape, dtype),
K: T.Tensor(kv_shape, dtype),
V: T.Tensor(kv_shape, dtype),
O_slc: T.Tensor(o_slc_shape, dtype),
BlockIndices: T.Tensor(block_indices_shape, block_indices_dtype),
BlockCounts: T.Tensor(block_counts_shape, block_counts_dtype),
Offsets: T.Tensor(offsets_shape, offsets_dtype),
TokenIndices: T.Tensor(token_indices_shape, token_indices_dtype),
):
with T.Kernel(c_seq_len, NV, batch * head_kv, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([G, BK], dtype)
......@@ -100,7 +94,7 @@ def native_sparse_attention_varlen(batch,
current_seq_len = eos - bos
NS = BlockCounts[i_t, i_h]
T.copy(Q[bos + i_t, i_h * G : (i_h + 1) * G, :BK], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
......@@ -112,21 +106,15 @@ def native_sparse_attention_varlen(batch,
# [BS, BK]
# Lei: may have some padding issues
# we should learn from mha varlen templates to handle this
T.copy(K[bos + i_s : bos + i_s + BS, i_h, :BK], K_shared)
if is_causal:
for i, j in T.Parallel(G, BS):
acc_s[i, j] = T.if_then_else(i_t >= (i_s + j), 0, -T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
# Softmax
T.copy(scores_max, scores_max_prev)
......@@ -146,13 +134,13 @@ def native_sparse_attention_varlen(batch,
acc_o[i, j] *= scores_scale[i]
# V * softmax(Q * K)
T.copy(V[bos + i_s : bos + i_s + BS, i_h, i_v * BV : (i_v + 1) * BV], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(G, BV):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, O_slc[bos + i_t, i_h * G : (i_h + 1) * G, i_v * BV : (i_v + 1) * BV])
return native_sparse_attention_varlen
......@@ -190,17 +178,20 @@ def parallel_nsa_fwd(
o_slc = torch.empty(B, C_SEQ_LEN, HQ, V, dtype=v.dtype, device=q.device)
kernel(
q.view(C_SEQ_LEN, HQ, D),
k.view(C_SEQ_LEN, H, D),
v.view(C_SEQ_LEN, H, D),
o_slc.view(C_SEQ_LEN, HQ, V),
block_indices.to(torch.int32).view(C_SEQ_LEN, H, S),
block_counts.to(torch.int32).view(C_SEQ_LEN, H),
offsets.to(torch.int32),
token_indices.to(torch.int32),
)
return o_slc
@torch.compile
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets):
ctx.dtype = q.dtype
......@@ -221,22 +212,25 @@ class ParallelNSAFunction(torch.autograd.Function):
window_size=window_size,
scale=scale,
offsets=offsets,
token_indices=token_indices,
)
return o_slc.to(q.dtype)
def parallel_nsa(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
......@@ -276,29 +270,27 @@ def parallel_nsa(q: torch.Tensor,
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if scale is None:
scale = k.shape[-1] ** -0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, "b h t -> b t h")
assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
if isinstance(block_counts, int):
block_indices = block_indices[:, :, :, :block_counts]
block_counts = None
o_slc = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens)
if window_size > 0:
assert False, "Window size is not supported yet"
else:
o = o_slc * g_slc.unsqueeze(-1)
if head_first:
o = rearrange(o, "b t h d -> b h t d")
return o
......@@ -306,41 +298,57 @@ if __name__ == "__main__":
N, C_SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 64, 1, 32, torch.float16
torch.manual_seed(42)
# randomly split the sequence into N segments
offsets = (
torch.cat(
[
torch.tensor([0], dtype=torch.long),
torch.arange(16, C_SEQ_LEN)[torch.randperm(C_SEQ_LEN - 1)[: N - 1]],
torch.tensor([C_SEQ_LEN], dtype=torch.long),
],
0,
)
.cuda()
.sort()[0]
)
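# Draws N - 1 interior cut points from [16, C_SEQ_LEN) and sorts, yielding
# offsets like [0, 41, 64] for N=2, C_SEQ_LEN=64 (illustrative values):
# variable-length sequences packed back-to-back along the flattened token axis.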
# seq-first required for inputs with variable lengths
perm_q = torch.randperm(C_SEQ_LEN, device="cuda")
perm_k = torch.randperm(C_SEQ_LEN, device="cuda")
perm_v = torch.randperm(C_SEQ_LEN, device="cuda")
q = (
torch.linspace(0, 1, steps=C_SEQ_LEN, dtype=dtype, device="cuda")[perm_q]
.view(1, C_SEQ_LEN, 1, 1)
.expand(1, C_SEQ_LEN, HQ, D)
.clone()
.requires_grad_(True)
)
k = (
torch.linspace(0, 1, steps=C_SEQ_LEN, dtype=dtype, device="cuda")[perm_k]
.view(1, C_SEQ_LEN, 1, 1)
.expand(1, C_SEQ_LEN, H, D)
.clone()
.requires_grad_(True)
)
v = (
torch.linspace(0, 1, steps=C_SEQ_LEN, dtype=dtype, device="cuda")[perm_v]
.view(1, C_SEQ_LEN, 1, 1)
.expand(1, C_SEQ_LEN, H, D)
.clone()
.requires_grad_(True)
)
g_slc = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True)
g_swa = torch.rand((1, C_SEQ_LEN, HQ), dtype=dtype, device="cuda").requires_grad_(True)
do = torch.randn((1, C_SEQ_LEN, HQ, D), dtype=dtype, device="cuda")
token_indices = prepare_token_indices(offsets).tolist()
block_indices = torch.full((1, C_SEQ_LEN, H, S), C_SEQ_LEN, dtype=torch.long, device="cuda")
for i in range(C_SEQ_LEN):
_, t = token_indices[i]
for h in range(H):
i_i = torch.randperm(max(1, tilelang.cdiv(t, block_size)))[:S]
block_indices[0, i, h, : len(i_i)] = i_i
block_indices = block_indices.sort(-1)[0]
block_counts = torch.randint(1, S + 1, (1, C_SEQ_LEN, H), device="cuda")
ref = naive_nsa(
q=q,
......@@ -351,7 +359,8 @@ if __name__ == "__main__":
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
cu_seqlens=offsets)
cu_seqlens=offsets,
)
tri = parallel_nsa(
q=q,
......@@ -362,7 +371,8 @@ if __name__ == "__main__":
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
cu_seqlens=offsets)
cu_seqlens=offsets,
)
print("tri", tri)
print("ref", ref)
......
......@@ -8,6 +8,7 @@ import triton
import triton.language as tl
import fla
if parse(fla.__version__) < parse("0.2.1"):
from fla.ops.common.utils import prepare_token_indices
else:
......@@ -17,21 +18,44 @@ from reference import naive_nsa
from einops import rearrange
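Hedged note before the decorated kernels: triton.heuristics evaluates each lambda against the kernel's launch arguments and injects the result as a constexpr, so the USE_OFFSETS / USE_BLOCK_COUNTS branches are specialized away at compile time rather than tested per element. A tiny illustration of how the rules resolve (the launch arguments here are hypothetical):
import torch

heuristics = {
    "USE_OFFSETS": lambda args: args["offsets"] is not None,
    "USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor),
}
launch_args = {"offsets": None, "block_counts": 16}  # hypothetical launch
print({name: rule(launch_args) for name, rule in heuristics.items()})
# {'USE_OFFSETS': False, 'USE_BLOCK_COUNTS': False}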
@triton.heuristics({
'USE_OFFSETS': lambda args: args['offsets'] is not None,
'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor),
})
@triton.heuristics(
{
"USE_OFFSETS": lambda args: args["offsets"] is not None,
"USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor),
}
)
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]],
key=['BS', 'BK', 'BV'],
key=["BS", "BK", "BV"],
)
@triton.jit
def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices,
block_counts, offsets, token_indices, T, H: tl.constexpr,
HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr,
S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr,
BV: tl.constexpr, USE_OFFSETS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr):
def parallel_nsa_fwd_kernel(
q,
k,
v,
o_slc,
o_swa,
lse_slc,
lse_swa,
scale,
block_indices,
block_counts,
offsets,
token_indices,
T,
H: tl.constexpr,
HQ: tl.constexpr,
G: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
S: tl.constexpr,
BS: tl.constexpr,
WS: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_OFFSETS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr,
):
i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
......@@ -46,20 +70,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
# else:
NS = S
p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK),
(1, 0))
p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
# the Q block is kept in the shared memory throughout the whole kernel
# [G, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV),
(G, BV), (1, 0))
p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_o_slc = tl.zeros([G, BV], dtype=tl.float32)
b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32)
b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32)
b_acc_slc = tl.zeros([G], dtype=tl.float32)
for i in range(NS):
i_s = tl.load(block_indices + i).to(tl.int32) * BS
......@@ -72,7 +94,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1))
# [G, BS]
b_s_slc = tl.dot(b_q, b_k_slc)
b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf'))
b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf"))
# [G]
b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc
......@@ -92,7 +114,6 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
......@@ -105,8 +126,7 @@ class ParallelNSAFunction(torch.autograd.Function):
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
token_indices = prepare_token_indices(offsets) if offsets is not None else None
o, lse = parallel_nsa_fwd(
q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale)
o, lse = parallel_nsa_fwd(q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale)
ctx.save_for_backward(q, k, v, o, lse)
ctx.block_indices = block_indices
ctx.block_size = block_size
......@@ -134,7 +154,8 @@ class ParallelNSAFunction(torch.autograd.Function):
window_size=ctx.window_size,
scale=ctx.scale,
offsets=ctx.offsets,
token_indices=ctx.token_indices)
token_indices=ctx.token_indices,
)
return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None
......@@ -199,37 +220,56 @@ def parallel_nsa_fwd(
return o_slc, lse_slc, o_swa, lse_swa
@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None})
@triton.heuristics({"USE_OFFSETS": lambda args: args["offsets"] is not None})
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
key=['BS', 'BK', 'BV'],
key=["BS", "BK", "BV"],
)
@triton.jit(do_not_specialize=['T'])
def parallel_nsa_bwd_kernel_dkv(q, k, v, lse_slc, lse_swa, delta_slc, delta_swa, do_slc, do_swa, dk,
dv, block_mask, offsets, chunk_indices, scale, T, B: tl.constexpr,
H: tl.constexpr, HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr,
V: tl.constexpr, M: tl.constexpr, BS: tl.constexpr,
WS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
USE_OFFSETS: tl.constexpr):
@triton.jit(do_not_specialize=["T"])
def parallel_nsa_bwd_kernel_dkv(
q,
k,
v,
lse_slc,
lse_swa,
delta_slc,
delta_swa,
do_slc,
do_swa,
dk,
dv,
block_mask,
offsets,
chunk_indices,
scale,
T,
B: tl.constexpr,
H: tl.constexpr,
HQ: tl.constexpr,
G: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
M: tl.constexpr,
BS: tl.constexpr,
WS: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_OFFSETS: tl.constexpr,
):
i_v, i_s, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
if USE_OFFSETS:
i_n, i_s = tl.load(chunk_indices + i_s * 2).to(tl.int32), tl.load(chunk_indices + i_s * 2 +
1).to(tl.int32)
i_n, i_s = tl.load(chunk_indices + i_s * 2).to(tl.int32), tl.load(chunk_indices + i_s * 2 + 1).to(tl.int32)
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_s * BS, 0), (BS, BK),
(1, 0))
p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV),
(BS, BV), (1, 0))
p_dk = tl.make_block_ptr(dk + (i_v * B * T * H + bos * H + i_h) * K, (T, K), (H * K, 1),
(i_s * BS, 0), (BS, BK), (1, 0))
p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV),
(BS, BV), (1, 0))
p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_s * BS, 0), (BS, BK), (1, 0))
p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0))
p_dk = tl.make_block_ptr(dk + (i_v * B * T * H + bos * H + i_h) * K, (T, K), (H * K, 1), (i_s * BS, 0), (BS, BK), (1, 0))
p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0))
# [BS, BK]
b_k = tl.load(p_k, boundary_check=(0, 1))
......@@ -241,14 +281,12 @@ def parallel_nsa_bwd_kernel_dkv(q, k, v, lse_slc, lse_swa, delta_slc, delta_swa,
for i in range(i_s * BS, T):
b_m_slc = tl.load(block_mask + (bos + i) * H * M + i_h * M + i_s)
if b_m_slc:
p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK),
(1, 0))
p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
# [G, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
p_do_slc = tl.make_block_ptr(do_slc + (bos + i) * HQ * V, (HQ, V), (V, 1),
(i_h * G, i_v * BV), (G, BV), (1, 0))
p_do_slc = tl.make_block_ptr(do_slc + (bos + i) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_slc = lse_slc + (bos + i) * HQ + i_h * G + tl.arange(0, G)
p_delta_slc = delta_slc + (bos + i) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
......@@ -272,14 +310,12 @@ def parallel_nsa_bwd_kernel_dkv(q, k, v, lse_slc, lse_swa, delta_slc, delta_swa,
if WS > 0:
o_s = i_s * BS + tl.arange(0, BS)
if max(i_s * BS, i - WS + 1) < min((i_s + 1) * BS, i + 1):
p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0),
(G, BK), (1, 0))
p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
# [G, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
p_do_swa = tl.make_block_ptr(do_swa + (bos + i) * HQ * V, (HQ, V), (V, 1),
(i_h * G, i_v * BV), (G, BV), (1, 0))
p_do_swa = tl.make_block_ptr(do_swa + (bos + i) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_swa = lse_swa + (bos + i) * HQ + i_h * G + tl.arange(0, G)
p_delta_swa = delta_swa + (bos + i) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
......@@ -304,12 +340,19 @@ def parallel_nsa_bwd_kernel_dkv(q, k, v, lse_slc, lse_swa, delta_slc, delta_swa,
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics(
{'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor)})
@triton.heuristics({"USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor)})
@triton.jit
def parallel_nsa_kernel_mask(block_indices, block_counts, block_mask, T: tl.constexpr,
H: tl.constexpr, S: tl.constexpr, BS: tl.constexpr, NS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr):
def parallel_nsa_kernel_mask(
block_indices,
block_counts,
block_mask,
T: tl.constexpr,
H: tl.constexpr,
S: tl.constexpr,
BS: tl.constexpr,
NS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr,
):
i_t, i_b, i_hs = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_h, i_s = i_hs // S, i_hs % S
......@@ -320,31 +363,56 @@ def parallel_nsa_kernel_mask(block_indices, block_counts, block_mask, T: tl.cons
b_m = b_i * BS <= i_t
if b_i < NS and b_i >= 0:
tl.store(block_mask + i_b * T * H * NS + i_t * H * NS + i_h * NS + b_i,
b_m.to(block_mask.dtype.element_ty))
tl.store(block_mask + i_b * T * H * NS + i_t * H * NS + i_h * NS + b_i, b_m.to(block_mask.dtype.element_ty))
@triton.heuristics({
'USE_OFFSETS': lambda args: args['offsets'] is not None,
'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor)
})
@triton.heuristics(
{
"USE_OFFSETS": lambda args: args["offsets"] is not None,
"USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor),
}
)
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
key=['BS', 'BK', 'BV'],
key=["BS", "BK", "BV"],
)
@triton.jit(do_not_specialize=['T'])
def parallel_nsa_bwd_kernel_dq(q, k, v, lse_slc, delta_slc, do_slc, lse_swa, delta_swa, do_swa, dq,
scale, block_indices, block_counts, offsets, token_indices, T,
B: tl.constexpr, H: tl.constexpr, HQ: tl.constexpr, G: tl.constexpr,
K: tl.constexpr, V: tl.constexpr, S: tl.constexpr, BS: tl.constexpr,
WS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
USE_OFFSETS: tl.constexpr, USE_BLOCK_COUNTS: tl.constexpr):
@triton.jit(do_not_specialize=["T"])
def parallel_nsa_bwd_kernel_dq(
q,
k,
v,
lse_slc,
delta_slc,
do_slc,
lse_swa,
delta_swa,
do_swa,
dq,
scale,
block_indices,
block_counts,
offsets,
token_indices,
T,
B: tl.constexpr,
H: tl.constexpr,
HQ: tl.constexpr,
G: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
S: tl.constexpr,
BS: tl.constexpr,
WS: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_OFFSETS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr,
):
i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
if USE_OFFSETS:
i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 +
1).to(tl.int32)
i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
T = eos - bos
else:
......@@ -449,27 +517,49 @@ def parallel_nsa_bwd_kernel_dq(q, k, v, lse_slc, delta_slc, do_slc, lse_swa, del
tl.store(p_dq, (b_dq_slc + b_dq_swa).to(p_dq.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({
'USE_OFFSETS': lambda args: args['offsets'] is not None,
'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor),
})
@triton.heuristics(
{
"USE_OFFSETS": lambda args: args["offsets"] is not None,
"USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor),
}
)
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
key=['BS', 'BK', 'BV'],
key=["BS", "BK", "BV"],
)
@triton.jit
def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices,
block_counts, offsets, token_indices, T, H: tl.constexpr,
HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr,
S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr,
BV: tl.constexpr, USE_OFFSETS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr):
def parallel_nsa_fwd_kernel(
q,
k,
v,
o_slc,
o_swa,
lse_slc,
lse_swa,
scale,
block_indices,
block_counts,
offsets,
token_indices,
T,
H: tl.constexpr,
HQ: tl.constexpr,
G: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
S: tl.constexpr,
BS: tl.constexpr,
WS: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_OFFSETS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr,
):
i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
if USE_OFFSETS:
i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 +
1).to(tl.int32)
i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
T = eos - bos
else:
......@@ -484,20 +574,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
else:
NS = S
p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK),
(1, 0))
p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
# the Q block is kept in the shared memory throughout the whole kernel
# [G, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV),
(G, BV), (1, 0))
p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_o_slc = tl.zeros([G, BV], dtype=tl.float32)
b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32)
b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32)
b_acc_slc = tl.zeros([G], dtype=tl.float32)
for i in range(NS):
i_s = tl.load(block_indices + i).to(tl.int32) * BS
......@@ -510,7 +598,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1))
# [G, BS]
b_s_slc = tl.dot(b_q, b_k_slc)
b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf'))
b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf"))
# [G]
b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc
......@@ -529,13 +617,12 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty))
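Hedged sketch of the streaming-softmax recurrence that the selected-block loop implements (running max m, rescaled normalizer acc, rescaled output o; the elided epilogue folds log(acc) into the stored LSE):
# Hedged sketch: online softmax over score blocks, per query group.
import torch

def online_softmax(score_blocks, value_blocks):
    m = torch.tensor(float("-inf"))
    acc = torch.tensor(0.0)
    o = torch.zeros(value_blocks[0].shape[-1])
    for s, v in zip(score_blocks, value_blocks):  # s: [BS], v: [BS, V]
        m_new = torch.maximum(m, s.max())
        r = torch.exp(m - m_new)   # rescale the old state
        p = torch.exp(s - m_new)   # unnormalized block probabilities
        acc = acc * r + p.sum()
        o = o * r + p @ v
        m = m_new
    return o / acc, m + torch.log(acc)  # normalized output and log-sum-exp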
if WS > 0:
p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1),
(i_h * G, i_v * BV), (G, BV), (1, 0))
p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_swa = lse_swa + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_o_swa = tl.zeros([G, BV], dtype=tl.float32)
b_m_swa = tl.full([G], float('-inf'), dtype=tl.float32)
b_m_swa = tl.full([G], float("-inf"), dtype=tl.float32)
b_acc_swa = tl.zeros([G], dtype=tl.float32)
for i_s in range(max(0, i_t - WS + 1), i_t + 1, BS):
p_k_swa = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
......@@ -546,7 +633,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
b_v_swa = tl.load(p_v_swa, boundary_check=(0, 1))
# [G, BS]
b_s_swa = tl.dot(b_q, b_k_swa)
b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float('-inf'))
b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float("-inf"))
# [G]
b_m_swa, b_mp_swa = tl.maximum(b_m_swa, tl.max(b_s_swa, 1)), b_m_swa
......@@ -593,14 +680,8 @@ def parallel_nsa_block_mask(
block_mask = torch.zeros(B, T, H, NS, dtype=torch.bool, device=block_indices.device)
parallel_nsa_kernel_mask[(T, B, H * S)](
block_indices=block_indices,
block_counts=block_counts,
block_mask=block_mask,
T=T,
H=H,
S=S,
BS=BS,
NS=NS)
block_indices=block_indices, block_counts=block_counts, block_mask=block_mask, T=T, H=H, S=S, BS=BS, NS=NS
)
return block_mask
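Hedged PyTorch reference for the mask kernel launched above, assuming NS = ceil(T / BS) as in the elided setup and that the USE_BLOCK_COUNTS branch drops slots s >= block_counts[b, t, h]:
# Hedged sketch: block_mask[b, t, h, b_i] marks block b_i as visible to token t.
import torch

def block_mask_ref(block_indices, block_counts, T, BS):
    B, T_, H, S = block_indices.shape
    NS = (T + BS - 1) // BS
    mask = torch.zeros(B, T_, H, NS, dtype=torch.bool, device=block_indices.device)
    for b in range(B):
        for t in range(T_):
            for h in range(H):
                for s in range(S):
                    b_i = int(block_indices[b, t, h, s])
                    if 0 <= b_i < NS:
                        visible = b_i * BS <= t  # block starts at or before t
                        if isinstance(block_counts, torch.Tensor):
                            visible = visible and s < int(block_counts[b, t, h])
                        mask[b, t, h, b_i] = visible
    return mask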
......@@ -676,7 +757,8 @@ def parallel_nsa_bwd(
BS=BS,
WS=WS,
BK=BK,
BV=BV)
BV=BV,
)
dq = dq.sum(0)
if offsets is not None:
......@@ -719,14 +801,14 @@ def parallel_nsa_bwd(
BS=BS,
WS=WS,
BK=BK,
BV=BV)
BV=BV,
)
dk = dk.sum(0)
return dq, dk, dv
@torch.compile
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
......@@ -749,7 +831,8 @@ class ParallelNSAFunction(torch.autograd.Function):
window_size=window_size,
scale=scale,
offsets=offsets,
token_indices=token_indices)
token_indices=token_indices,
)
ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa)
ctx.block_indices = block_indices
ctx.block_counts = block_counts
......@@ -781,22 +864,25 @@ class ParallelNSAFunction(torch.autograd.Function):
window_size=ctx.window_size,
scale=ctx.scale,
offsets=ctx.offsets,
token_indices=ctx.token_indices)
token_indices=ctx.token_indices,
)
return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None
def parallel_nsa(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False) -> torch.Tensor:
def parallel_nsa(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
......@@ -836,51 +922,49 @@ def parallel_nsa(q: torch.Tensor,
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if scale is None:
scale = k.shape[-1]**-0.5
scale = k.shape[-1] ** -0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'),
(q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa))
q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, 'b h t -> b t h')
block_counts = rearrange(block_counts, "b h t -> b t h")
assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
if isinstance(block_counts, int):
block_indices = block_indices[:, :, :, :block_counts]
block_counts = None
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size,
window_size, scale, cu_seqlens)
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens)
if window_size > 0:
o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1))
else:
o = o_slc * g_slc.unsqueeze(-1)
if head_first:
o = rearrange(o, 'b t h d -> b h t d')
o = rearrange(o, "b t h d -> b h t d")
return o
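Hedged note on the gating above: torch.addcmul(a, t1, t2) computes a + t1 * t2 elementwise, so the window branch is o = o_slc * g_slc + o_swa * g_swa with the gates broadcast over the value dimension. A tiny equivalence check:
import torch

o_slc, o_swa = torch.randn(2, 3, 4), torch.randn(2, 3, 4)
g_slc, g_swa = torch.rand(2, 3), torch.rand(2, 3)
a = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1))
b = o_slc * g_slc.unsqueeze(-1) + o_swa * g_swa.unsqueeze(-1)
assert torch.allclose(a, b)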
if __name__ == "__main__":
B, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 32, 1, 32, torch.float16
torch.random.manual_seed(0)
q = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
k = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
v = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
g_slc = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
g_swa = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
do = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda')
block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='cuda')
q = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda").requires_grad_(True)
k = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True)
v = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True)
g_slc = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True)
g_swa = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True)
do = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda")
block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device="cuda")
for b in range(B):
for t in range(T):
for h in range(H):
i_i = torch.randperm(max(1, (t // block_size)))[:S]
block_indices[b, t, h, :len(i_i)] = i_i
block_indices[b, t, h, : len(i_i)] = i_i
block_indices = block_indices.sort(-1)[0]
block_counts = torch.randint(1, S + 1, (B, T, H), device='cuda')
block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda")
ref = naive_nsa(
q=q,
......
......@@ -8,6 +8,7 @@ import triton
import triton.language as tl
import fla
if parse(fla.__version__) < parse("0.2.1"):
from fla.ops.common.utils import prepare_token_indices
else:
......@@ -17,21 +18,44 @@ from reference import naive_nsa
from einops import rearrange
@triton.heuristics({
'USE_OFFSETS': lambda args: args['offsets'] is not None,
'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor),
})
@triton.heuristics(
{
"USE_OFFSETS": lambda args: args["offsets"] is not None,
"USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor),
}
)
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]],
key=['BS', 'BK', 'BV'],
key=["BS", "BK", "BV"],
)
@triton.jit
def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices,
block_counts, offsets, token_indices, T, H: tl.constexpr,
HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr,
S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr,
BV: tl.constexpr, USE_OFFSETS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr):
def parallel_nsa_fwd_kernel(
q,
k,
v,
o_slc,
o_swa,
lse_slc,
lse_swa,
scale,
block_indices,
block_counts,
offsets,
token_indices,
T,
H: tl.constexpr,
HQ: tl.constexpr,
G: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
S: tl.constexpr,
BS: tl.constexpr,
WS: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_OFFSETS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr,
):
i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
......@@ -46,20 +70,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
# else:
NS = S
p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK),
(1, 0))
p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
# the Q block is kept in the shared memory throughout the whole kernel
# [G, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV),
(G, BV), (1, 0))
p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_o_slc = tl.zeros([G, BV], dtype=tl.float32)
b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32)
b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32)
b_acc_slc = tl.zeros([G], dtype=tl.float32)
for i in range(NS):
i_s = tl.load(block_indices + i).to(tl.int32) * BS
......@@ -72,7 +94,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1))
# [G, BS]
b_s_slc = tl.dot(b_q, b_k_slc)
b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf'))
b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf"))
# [G]
b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc
......@@ -92,7 +114,6 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
......@@ -105,8 +126,7 @@ class ParallelNSAFunction(torch.autograd.Function):
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
token_indices = prepare_token_indices(offsets) if offsets is not None else None
o, lse = parallel_nsa_fwd(
q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale)
o, lse = parallel_nsa_fwd(q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale)
ctx.save_for_backward(q, k, v, o, lse)
ctx.block_indices = block_indices
ctx.block_size = block_size
......@@ -177,7 +197,6 @@ def parallel_nsa_fwd(
@torch.compile
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
......@@ -200,7 +219,8 @@ class ParallelNSAFunction(torch.autograd.Function):
window_size=window_size,
scale=scale,
offsets=offsets,
token_indices=token_indices)
token_indices=token_indices,
)
ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa)
ctx.block_indices = block_indices
ctx.block_counts = block_counts
......@@ -212,18 +232,20 @@ class ParallelNSAFunction(torch.autograd.Function):
return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa
def parallel_nsa(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False) -> torch.Tensor:
def parallel_nsa(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
......@@ -263,51 +285,49 @@ def parallel_nsa(q: torch.Tensor,
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if scale is None:
scale = k.shape[-1]**-0.5
scale = k.shape[-1] ** -0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'),
(q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa))
q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, 'b h t -> b t h')
block_counts = rearrange(block_counts, "b h t -> b t h")
assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
if isinstance(block_counts, int):
block_indices = block_indices[:, :, :, :block_counts]
block_counts = None
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size,
window_size, scale, cu_seqlens)
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens)
if window_size > 0:
o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1))
else:
o = o_slc * g_slc.unsqueeze(-1)
if head_first:
o = rearrange(o, 'b t h d -> b h t d')
o = rearrange(o, "b t h d -> b h t d")
return o
if __name__ == "__main__":
B, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 32, 1, 32, torch.float16
torch.random.manual_seed(0)
q = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
k = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
v = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
g_slc = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
g_swa = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
do = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda')
block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='cuda')
q = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda").requires_grad_(True)
k = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True)
v = torch.randn((B, T, H, D), dtype=dtype, device="cuda").requires_grad_(True)
g_slc = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True)
g_swa = torch.ones((B, T, HQ), dtype=dtype, device="cuda").requires_grad_(True)
do = torch.randn((B, T, HQ, D), dtype=dtype, device="cuda")
block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device="cuda")
for b in range(B):
for t in range(T):
for h in range(H):
i_i = torch.randperm(max(1, (t // block_size)))[:S]
block_indices[b, t, h, :len(i_i)] = i_i
block_indices[b, t, h, : len(i_i)] = i_i
block_indices = block_indices.sort(-1)[0]
block_counts = torch.randint(1, S + 1, (B, T, H), device='cuda')
block_counts = torch.randint(1, S + 1, (B, T, H), device="cuda")
ref = naive_nsa(
q=q,
......
......@@ -8,6 +8,7 @@ import triton
import triton.language as tl
import fla
if parse(fla.__version__) < parse("0.2.1"):
from fla.ops.common.utils import prepare_token_indices
else:
......@@ -17,27 +18,49 @@ from reference import naive_nsa
from einops import rearrange
@triton.heuristics({
'USE_OFFSETS': lambda args: args['offsets'] is not None,
'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor),
})
@triton.heuristics(
{
"USE_OFFSETS": lambda args: args["offsets"] is not None,
"USE_BLOCK_COUNTS": lambda args: isinstance(args["block_counts"], torch.Tensor),
}
)
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
key=['BS', 'BK', 'BV'],
key=["BS", "BK", "BV"],
)
@triton.jit
def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices,
block_counts, offsets, token_indices, T, H: tl.constexpr,
HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr,
S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr,
BV: tl.constexpr, USE_OFFSETS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr):
def parallel_nsa_fwd_kernel(
q,
k,
v,
o_slc,
o_swa,
lse_slc,
lse_swa,
scale,
block_indices,
block_counts,
offsets,
token_indices,
T,
H: tl.constexpr,
HQ: tl.constexpr,
G: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
S: tl.constexpr,
BS: tl.constexpr,
WS: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_OFFSETS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr,
):
i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
if USE_OFFSETS:
i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 +
1).to(tl.int32)
i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
T = eos - bos
else:
......@@ -52,20 +75,18 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
else:
NS = S
p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK),
(1, 0))
p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
# the Q block is kept in the shared memory throughout the whole kernel
# [G, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV),
(G, BV), (1, 0))
p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_o_slc = tl.zeros([G, BV], dtype=tl.float32)
b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32)
b_m_slc = tl.full([G], float("-inf"), dtype=tl.float32)
b_acc_slc = tl.zeros([G], dtype=tl.float32)
for i in range(NS):
i_s = tl.load(block_indices + i).to(tl.int32) * BS
......@@ -78,7 +99,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1))
# [G, BS]
b_s_slc = tl.dot(b_q, b_k_slc)
b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf'))
b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float("-inf"))
# [G]
b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc
......@@ -97,13 +118,12 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty))
if WS > 0:
p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1),
(i_h * G, i_v * BV), (G, BV), (1, 0))
p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_swa = lse_swa + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_o_swa = tl.zeros([G, BV], dtype=tl.float32)
b_m_swa = tl.full([G], float('-inf'), dtype=tl.float32)
b_m_swa = tl.full([G], float("-inf"), dtype=tl.float32)
b_acc_swa = tl.zeros([G], dtype=tl.float32)
for i_s in range(max(0, i_t - WS + 1), i_t + 1, BS):
p_k_swa = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
......@@ -114,7 +134,7 @@ def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, bloc
b_v_swa = tl.load(p_v_swa, boundary_check=(0, 1))
# [G, BS]
b_s_swa = tl.dot(b_q, b_k_swa)
b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float('-inf'))
b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float("-inf"))
# [G]
b_m_swa, b_mp_swa = tl.maximum(b_m_swa, tl.max(b_s_swa, 1)), b_m_swa
......@@ -196,7 +216,6 @@ def parallel_nsa_fwd(
@torch.compile
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
......@@ -219,7 +238,8 @@ class ParallelNSAFunction(torch.autograd.Function):
window_size=window_size,
scale=scale,
offsets=offsets,
token_indices=token_indices)
token_indices=token_indices,
)
ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa)
ctx.block_indices = block_indices
ctx.block_counts = block_counts
......@@ -231,18 +251,20 @@ class ParallelNSAFunction(torch.autograd.Function):
return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa
def parallel_nsa(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False) -> torch.Tensor:
def parallel_nsa(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
......@@ -282,29 +304,27 @@ def parallel_nsa(q: torch.Tensor,
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if scale is None:
scale = k.shape[-1]**-0.5
scale = k.shape[-1] ** -0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'),
(q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa))
q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, 'b h t -> b t h')
block_counts = rearrange(block_counts, "b h t -> b t h")
assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
if isinstance(block_counts, int):
block_indices = block_indices[:, :, :, :block_counts]
block_counts = None
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size,
window_size, scale, cu_seqlens)
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, window_size, scale, cu_seqlens)
if window_size > 0:
o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1))
else:
o = o_slc * g_slc.unsqueeze(-1)
if head_first:
o = rearrange(o, 'b t h d -> b h t d')
o = rearrange(o, "b t h d -> b h t d")
return o
......@@ -312,38 +332,35 @@ if __name__ == "__main__":
N, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 64, 1, 32, torch.float16
torch.manual_seed(42)
# randomly split the sequence into N segments
offsets = torch.cat([
torch.tensor([0], dtype=torch.long),
torch.arange(16, T)[torch.randperm(T - 1)[:N - 1]],
torch.tensor([T], dtype=torch.long)
], 0).cuda().sort()[0]
offsets = (
torch.cat(
[torch.tensor([0], dtype=torch.long), torch.arange(16, T)[torch.randperm(T - 1)[: N - 1]], torch.tensor([T], dtype=torch.long)],
0,
)
.cuda()
.sort()[0]
)
# offsets.shape is [N+1]
# seq-first required for inputs with variable lengths
perm_q = torch.randperm(T, device='cuda')
perm_k = torch.randperm(T, device='cuda')
perm_v = torch.randperm(T, device='cuda')
q = torch.linspace(
0, 1, steps=T, dtype=dtype,
device='cuda')[perm_q].view(1, T, 1, 1).expand(1, T, HQ, D).clone().requires_grad_(True)
k = torch.linspace(
0, 1, steps=T, dtype=dtype,
device='cuda')[perm_k].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True)
v = torch.linspace(
0, 1, steps=T, dtype=dtype,
device='cuda')[perm_v].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True)
g_slc = torch.rand((1, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
g_swa = torch.rand((1, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
do = torch.randn((1, T, HQ, D), dtype=dtype, device='cuda')
perm_q = torch.randperm(T, device="cuda")
perm_k = torch.randperm(T, device="cuda")
perm_v = torch.randperm(T, device="cuda")
q = torch.linspace(0, 1, steps=T, dtype=dtype, device="cuda")[perm_q].view(1, T, 1, 1).expand(1, T, HQ, D).clone().requires_grad_(True)
k = torch.linspace(0, 1, steps=T, dtype=dtype, device="cuda")[perm_k].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True)
v = torch.linspace(0, 1, steps=T, dtype=dtype, device="cuda")[perm_v].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True)
g_slc = torch.rand((1, T, HQ), dtype=dtype, device="cuda").requires_grad_(True)
g_swa = torch.rand((1, T, HQ), dtype=dtype, device="cuda").requires_grad_(True)
do = torch.randn((1, T, HQ, D), dtype=dtype, device="cuda")
token_indices = prepare_token_indices(offsets).tolist()
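Hedged sketch of what prepare_token_indices(offsets) returns: one (seq_id, local_position) pair per packed token, so offsets = [0, 2, 6] maps to [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]], matching the comment earlier in the diff:
import torch

def prepare_token_indices_ref(offsets):
    pairs = []
    for n in range(len(offsets) - 1):
        for t in range(int(offsets[n + 1] - offsets[n])):
            pairs.append([n, t])
    return torch.tensor(pairs, dtype=torch.long)

print(prepare_token_indices_ref(torch.tensor([0, 2, 6])))
# tensor([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]])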
block_indices = torch.full((1, T, H, S), T, dtype=torch.long, device='cuda')
block_indices = torch.full((1, T, H, S), T, dtype=torch.long, device="cuda")
for i in range(T):
_, t = token_indices[i]
for h in range(H):
i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S]
block_indices[0, i, h, :len(i_i)] = i_i
block_indices[0, i, h, : len(i_i)] = i_i
block_indices = block_indices.sort(-1)[0]
block_counts = torch.randint(1, S + 1, (1, T, H), device='cuda')
block_counts = torch.randint(1, S + 1, (1, T, H), device="cuda")
ref = naive_nsa(
q=q,
......@@ -354,7 +371,8 @@ if __name__ == "__main__":
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
cu_seqlens=offsets)
cu_seqlens=offsets,
)
tri = parallel_nsa(
q=q,
......@@ -365,7 +383,8 @@ if __name__ == "__main__":
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
cu_seqlens=offsets)
cu_seqlens=offsets,
)
print("tri", tri)
print("ref", ref)
......
......@@ -6,18 +6,20 @@ from typing import Union
from einops import rearrange, repeat
def naive_nsa(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False) -> torch.Tensor:
def naive_nsa(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
......@@ -57,26 +59,24 @@ def naive_nsa(q: torch.Tensor,
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if scale is None:
scale = k.shape[-1]**-0.5
scale = k.shape[-1] ** -0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
raise RuntimeError(
"Sequences with variable lengths are not supported for head-first mode")
raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'),
(q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa))
q, k, v, block_indices = map(lambda x: rearrange(x, "b h t d -> b t h d"), (q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, "b h t -> b t h"), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, 'b h t -> b t h')
block_counts = rearrange(block_counts, "b h t -> b t h")
dtype = q.dtype
G = q.shape[2] // k.shape[2]
BS = block_size
S = block_indices.shape[-1]
k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices))
k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices))
if isinstance(block_counts, torch.Tensor):
block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G)
block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G)
c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device)
q, k, v = map(lambda x: x.float(), (q, k, v))
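Hedged note on the `c` tensor built above: it tags each of the S*BS gathered key rows with its block-slot ordinal, so the later comparison (c >= s_i) masks out whole blocks past the per-token block count. A tiny illustration:
import torch

S, BS, HQ = 3, 2, 4
c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, HQ)
print(c[:, 0])           # tensor([0, 0, 1, 1, 2, 2])
s_i = torch.tensor(2)    # this token uses only the first 2 selected blocks
print((c >= s_i)[:, 0])  # tensor([False, False, False, False, True, True])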
......@@ -86,14 +86,11 @@ def naive_nsa(q: torch.Tensor,
if cu_seqlens is None:
varlen = False
B, T = q.shape[:2]
cu_seqlens = torch.cat(
[block_indices.new_tensor(range(0, B * T, T)),
block_indices.new_tensor([B * T])])
cu_seqlens = torch.cat([block_indices.new_tensor(range(0, B * T, T)), block_indices.new_tensor([B * T])])
for i in range(len(cu_seqlens) - 1):
if not varlen:
q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[
i], block_indices[i]
q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[i], block_indices[i]
if isinstance(block_counts, torch.Tensor):
s_b = block_counts[i]
else:
......@@ -101,10 +98,10 @@ def naive_nsa(q: torch.Tensor,
else:
T = cu_seqlens[i + 1] - cu_seqlens[i]
q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = map(
lambda x: x[0][cu_seqlens[i]:cu_seqlens[i + 1]],
(q, k, v, g_slc, g_swa, block_indices))
lambda x: x[0][cu_seqlens[i] : cu_seqlens[i + 1]], (q, k, v, g_slc, g_swa, block_indices)
)
if isinstance(block_counts, torch.Tensor):
s_b = block_counts[0][cu_seqlens[i]:cu_seqlens[i + 1]]
s_b = block_counts[0][cu_seqlens[i] : cu_seqlens[i + 1]]
else:
s_b = block_counts
......@@ -126,34 +123,28 @@ def naive_nsa(q: torch.Tensor,
else:
s_i = s_b
# [S*BS, HQ, -1]
k_i_slc, v_i_slc = map(
lambda x: x.gather(
0,
i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b))
k_i_slc, v_i_slc = map(lambda x: x.gather(0, i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b))
# [S*BS, HQ]
attn_slc = torch.einsum('h d, n h d -> n h', q_i, k_i_slc).masked_fill(
torch.logical_or(i_i < 0, i_i > i_q) |
(c >= s_i if block_counts is not None else False), float('-inf')).softmax(0)
attn_slc = (
torch.einsum("h d, n h d -> n h", q_i, k_i_slc)
.masked_fill(torch.logical_or(i_i < 0, i_i > i_q) | (c >= s_i if block_counts is not None else False), float("-inf"))
.softmax(0)
)
if not varlen:
o_slc[i, i_q] = torch.einsum('n h, n h v -> h v', attn_slc,
v_i_slc) * g_slc_i.unsqueeze(-1)
o_slc[i, i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1)
else:
o_slc[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_slc,
v_i_slc) * g_slc_i.unsqueeze(-1)
o_slc[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_slc, v_i_slc) * g_slc_i.unsqueeze(-1)
if window_size > 0:
k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1):i_q + 1],
(k_b, v_b))
attn_swa = torch.einsum('h d, n h d -> n h', q_i, k_i_swa).softmax(0)
k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1) : i_q + 1], (k_b, v_b))
attn_swa = torch.einsum("h d, n h d -> n h", q_i, k_i_swa).softmax(0)
if not varlen:
o_swa[i, i_q] = torch.einsum('n h, n h v -> h v', attn_swa,
v_i_swa) * g_swa_i.unsqueeze(-1)
o_swa[i, i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1)
else:
o_swa[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_swa,
v_i_swa) * g_swa_i.unsqueeze(-1)
o_swa[0][cu_seqlens[i] + i_q] = torch.einsum("n h, n h v -> h v", attn_swa, v_i_swa) * g_swa_i.unsqueeze(-1)
if head_first:
o_slc = rearrange(o_slc, 'b t h d -> b h t d')
o_swa = rearrange(o_swa, 'b t h d -> b h t d')
o_slc = rearrange(o_slc, "b t h d -> b h t d")
o_swa = rearrange(o_swa, "b t h d -> b h t d")
return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype)
......@@ -187,7 +178,7 @@ def naive_nsa_simple(
o (torch.Tensor):
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
scale = k.shape[-1]**-0.5
scale = k.shape[-1] ** -0.5
dtype = q.dtype
HQ = q.shape[2]
......@@ -197,8 +188,8 @@ def naive_nsa_simple(
BS = block_size
S = block_indices.shape[-1]
SELECTED_BLOCKS_SIZE = S * BS
k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices))
block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G)
k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices))
block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G)
c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device)
q, k, v = map(lambda x: x.float(), (q, k, v))
o = torch.zeros_like(v)
......@@ -228,10 +219,10 @@ def naive_nsa_simple(
v_i[t, h] = v_b[selected_block_index, h, :]
# [S*BS, HQ]
attn = torch.einsum('h d, n h d -> n h', q_i, k_i)
attn = attn.masked_fill((i_i > i_q) | (c >= s_i), float('-inf'))
attn = torch.einsum("h d, n h d -> n h", q_i, k_i)
attn = attn.masked_fill((i_i > i_q) | (c >= s_i), float("-inf"))
attn = torch.softmax(attn, dim=0)
o[i, i_q] = torch.einsum('n h, n h v -> h v', attn, v_i)
o[i, i_q] = torch.einsum("n h, n h v -> h v", attn, v_i)
return o.to(dtype)
......@@ -265,7 +256,7 @@ def naive_nsa_simple_inference(
o (torch.Tensor):
Outputs of shape `[B, 1, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
scale = k.shape[-1]**-0.5
scale = k.shape[-1] ** -0.5
dtype = q.dtype
HQ = q.shape[2]
......@@ -275,8 +266,8 @@ def naive_nsa_simple_inference(
BS = block_size
S = block_indices.shape[-1]
SELECTED_BLOCKS_SIZE = S * BS
k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices))
block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G)
k, v, block_indices = (repeat(x, "b t h d -> b t (h g) d", g=G) for x in (k, v, block_indices))
block_counts = repeat(block_counts, "b t h -> b t (h g)", g=G)
c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device)
q, k, v = map(lambda x: x.float(), (q, k, v))
o = torch.zeros_like(q)
......@@ -306,9 +297,9 @@ def naive_nsa_simple_inference(
v_i[t, h] = v_b[selected_block_index, h, :]
# [S*BS, HQ]
attn = torch.einsum('h d, n h d -> n h', q_i, k_i)
attn = attn.masked_fill((c >= s_i), float('-inf'))
attn = torch.einsum("h d, n h d -> n h", q_i, k_i)
attn = attn.masked_fill((c >= s_i), float("-inf"))
attn = torch.softmax(attn, dim=0)
o[i, 0] = torch.einsum('n h, n h v -> h v', attn, v_i)
o[i, 0] = torch.einsum("n h, n h v -> h v", attn, v_i)
return o.to(dtype)
......@@ -28,11 +28,11 @@ def validate_tensor_match(a, b, tolerance=1e-8, tensor_name="tensor", should_rai
if should_raise:
assert False
if not torch.isclose(
a.masked_fill(a_finite, 0),
b.masked_fill(b_finite, 0),
rtol=0,
atol=0,
equal_nan=True,
a.masked_fill(a_finite, 0),
b.masked_fill(b_finite, 0),
rtol=0,
atol=0,
equal_nan=True,
).all():
display_error_message(f"{tensor_name} Error: nonfinite value mismatch")
if should_raise:
......@@ -55,13 +55,10 @@ def get_configs():
threads=[128, 256],
block_Q=[1, 2, 4],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
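Hedged sketch of the config expansion above: each row of the cartesian product is zipped back to its parameter names, yielding one dict per tuning candidate.
import itertools

iter_params = dict(threads=[128, 256], block_Q=[1, 2, 4])
configs = [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
print(configs[0])    # {'threads': 128, 'block_Q': 1}
print(len(configs))  # 6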
class SupplyProg:
def __init__(self):
self.tensors_dict = {}
......@@ -88,7 +85,8 @@ supply_prog = SupplyProg()
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},)
},
)
def mqa_attn_return_logits(
heads,
index_dim,
......@@ -113,16 +111,15 @@ def mqa_attn_return_logits(
@T.prim_func
def mqa_attn_return_logits_kernel(
IndexQ: T.Tensor(index_q_shape, dtype), # type: ignore
IndexK: T.Tensor(index_k_shape, dtype), # type: ignore
IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype), # type: ignore
Logits: T.Tensor(logits_shape, accum_dtype), # type: ignore
Weights: T.Tensor([seq_len, heads], accum_dtype), # type: ignore
CuSeqLenKS: T.Tensor([seq_len], index_dtype), # type: ignore
CuSeqLenKE: T.Tensor([seq_len], index_dtype), # type: ignore
IndexQ: T.Tensor(index_q_shape, dtype), # type: ignore
IndexK: T.Tensor(index_k_shape, dtype), # type: ignore
IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype), # type: ignore
Logits: T.Tensor(logits_shape, accum_dtype), # type: ignore
Weights: T.Tensor([seq_len, heads], accum_dtype), # type: ignore
CuSeqLenKS: T.Tensor([seq_len], index_dtype), # type: ignore
CuSeqLenKE: T.Tensor([seq_len], index_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx:
index_q_shared = T.alloc_shared([block_Q * heads, index_dim], dtype)
index_k_shared = T.alloc_shared([block_N, index_dim], dtype)
index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype)
......@@ -140,17 +137,14 @@ def mqa_attn_return_logits(
cu_k_e_max[0] = -2147483648
for bq_i in T.serial(block_Q):
cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i],
seq_len_kv))
cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv))
for bq_i in T.serial(block_Q):
cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i],
seq_len_kv))
cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv))
T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared)
T.copy(Weights[seq_len_i, 0], weights)
for nbn_i in T.Pipelined(
T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages):
for nbn_i in T.Pipelined(T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages):
T.copy(IndexK[cu_k_s_min[0] + nbn_i * block_N, 0], index_k_shared)
T.copy(IndexKScale[cu_k_s_min[0] + nbn_i * block_N], index_k_scale_fragment)
......@@ -164,15 +158,14 @@ def mqa_attn_return_logits(
)
for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
s_reshaped[bn_i, bq_i,
h_i] = (T.max(s_reshaped[bn_i, bq_i, h_i], 0) *
weights[bq_i, h_i]) * index_k_scale_fragment[bn_i]
s_reshaped[bn_i, bq_i, h_i] = (T.max(s_reshaped[bn_i, bq_i, h_i], 0) * weights[bq_i, h_i]) * index_k_scale_fragment[
bn_i
]
T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)
for bq_i, bn_i in T.Parallel(block_Q, block_N):
Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = (
logits[bn_i, bq_i])
Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = logits[bn_i, bq_i]
return mqa_attn_return_logits_kernel
......@@ -190,9 +183,9 @@ def clean_logits_(
@T.prim_func
def clean_logits_kernel(
Logits: T.Tensor([seq_len, seq_len_kv], dtype), # type: ignore
CuSeqLenKS: T.Tensor([seq_len], indices_dtype), # type: ignore
CuSeqLenKE: T.Tensor([seq_len], indices_dtype), # type: ignore
Logits: T.Tensor([seq_len, seq_len_kv], dtype), # type: ignore
CuSeqLenKS: T.Tensor([seq_len], indices_dtype), # type: ignore
CuSeqLenKE: T.Tensor([seq_len], indices_dtype), # type: ignore
):
with T.Kernel(seq_len, threads=threads) as bx:
tx = T.thread_binding(0, threads, thread="threadIdx.x")
......@@ -210,13 +203,7 @@ def clean_logits_(
return clean_logits_kernel
def mqa_attn_return_logits_interface(q,
kv,
kv_scales,
weights,
cu_seqlen_ks,
cu_seqlen_ke,
clean_logits=True):
def mqa_attn_return_logits_interface(q, kv, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=True):
seq_len, heads, index_dim = q.shape
seq_len_kv = kv.shape[0]
......@@ -238,20 +225,19 @@ def mqa_attn_return_logits_interface(q,
return logits
def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor,
cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor):
def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor):
k = kv
q = q.float()
k = k.float()
seq_len_kv = kv.shape[0]
mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None]
mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None]
mask_lo = torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
mask_hi = torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
mask = mask_lo & mask_hi
score = torch.einsum('mhd,nd->hmn', q, k)
score = torch.einsum("mhd,nd->hmn", q, k)
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
logits = logits.masked_fill(~mask, float('-inf'))
logits = logits.masked_fill(~mask, float("-inf"))
cost = mask.sum()
return logits, cost
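Hedged sketch of the per-query visibility window built above: key n is visible to query m iff cu_seqlen_ks[m] <= n < cu_seqlen_ke[m], and `cost` counts the visible pairs:
import torch

ks = torch.tensor([0, 0, 2])
ke = torch.tensor([2, 3, 4])
n = torch.arange(4)
mask = (n[None, :] >= ks[:, None]) & (n[None, :] < ke[:, None])
print(mask.int())
# tensor([[1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [0, 0, 1, 1]], dtype=torch.int32)
print(mask.sum())  # tensor(7) == cost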
......@@ -265,32 +251,22 @@ def test_fp8_lighting_indexer(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1):
weights = torch.randn(S, H, device="cuda", dtype=torch.float32)
p = (torch.randn(S, SKV, device="cuda", dtype=torch.float32) * 4).softmax(dim=-1)
ks, ke = generate_random_cu_seqlens(
per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048)
ks, ke = generate_random_cu_seqlens(per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048)
logits_ref, cost_ref = ref_fp8_mqa_logits(
q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke)
logits_ref, cost_ref = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke)
q_fp8 = q.to(torch.float8_e4m3fn)
kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0,), False)
logits_tl = mqa_attn_return_logits_interface(
q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke)
diff = validate_tensor_match(
logits_ref, logits_tl, tolerance=1e-14, tensor_name="logits", should_raise=False)
logits_tl = mqa_attn_return_logits_interface(q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke)
diff = validate_tensor_match(logits_ref, logits_tl, tolerance=1e-14, tensor_name="logits", should_raise=False)
print(f"diff: {diff}")
from tilelang.profiler import do_bench
def logits_fn():
return mqa_attn_return_logits_interface(
q=q_fp8,
kv=kv_fp8,
kv_scales=kv_scales,
weights=weights,
cu_seqlen_ks=ks,
cu_seqlen_ke=ke)
return mqa_attn_return_logits_interface(q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke)
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
logits_fn()
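Hedged alternative to the profiler block for a quick latency number, using plain CUDA events so it does not depend on the exact signature of tilelang's do_bench (assumes a CUDA device is available):
import torch

def bench(fn, iters=100):
    for _ in range(10):  # warmup
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

# e.g. print(f"{bench(logits_fn):.3f} ms")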
......