Unverified commit 29051439, authored by Lei Wang, committed by GitHub

[Lint] Phaseout Yapf format and embrace ruff format (#1417)

parent e84b24bc
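The change is purely mechanical: call sites that yapf had wrapped to fit its column limit are re-emitted in ruff's style (one line when it fits, otherwise one argument per line with a trailing comma), as the hunks below show. A minimal before/after illustration, using a call taken from this diff; running `ruff format` over the examples is assumed to reproduce the new layout, but the project's exact ruff configuration is not shown in this section:

# yapf (before): packed, wrapped continuation lines
T.gemm(
    Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)

# ruff format (after): a single line, since it fits the configured line length
T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)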
@@ -22,9 +22,9 @@ def preprocess(
    @T.prim_func
    def preprocess_kernel(
        O: T.Tensor(shape, dtype),
        dO: T.Tensor(shape, dtype),
        Delta: T.Tensor([B, S, H], accum_dtype),
    ):
        with T.Kernel(H, T.ceildiv(S, block_ND), B) as (bx, by, bz):
            o = T.alloc_fragment([block_ND, block_ND], accum_dtype)
@@ -33,16 +33,12 @@ def preprocess(
            acc = T.alloc_fragment([block_ND, block_ND], accum_dtype)
            T.clear(acc)
            for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages):
                T.copy(O[bz, by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], o)
                T.copy(dO[bz, by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], do)
                for i, j in T.Parallel(block_ND, block_ND):
                    acc[i, j] += o[i, j] * do[i, j]
            T.reduce_sum(acc, delta, 1)
            T.copy(delta, Delta[bz, by * block_ND : (by + 1) * block_ND, bx])

    return preprocess_kernel
@@ -65,13 +61,13 @@ def postprocess(
    @T.prim_func
    def postprocess_kernel(
        dKV: T.Tensor(dkv_shape, accum_dtype),
        dKV_out: T.Tensor(dkv_shape, dtype),
    ):
        with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, B, threads=threads) as (bx, by, bz):
            T.copy(
                dKV[bz, bx * block_N : (bx + 1) * block_N, by, :],
                dKV_out[bz, bx * block_N : (bx + 1) * block_N, by, :],
            )

    return postprocess_kernel
@@ -83,7 +79,8 @@ def postprocess(
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True,
    },
)
def bwd(
    B,
    S,
@@ -102,14 +99,14 @@ def bwd(
    dtype="bfloat16",
    accum_dtype="float",
):
    assert is_causal == True, "non-casual is not supported now"
    assert topk % block_size == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded"
    assert dtype == "bfloat16"
    assert accum_dtype == "float"
    assert indices_dtype == "int32"
    if sm_scale is None:
        sm_scale = (D + D_tail) ** (-0.5)
    sm_scale_mul_reciprocal_log2 = sm_scale * 1.44269504  # log2(e)
    H_kv = H // kv_group
@@ -132,14 +129,14 @@ def bwd(
    @T.prim_func
    def sparse_mla_bwd_kernel(
        Q: T.Tensor(q_shape, dtype),
        KV: T.Tensor(k_shape, dtype),
        dO: T.Tensor(o_shape, dtype),
        Indices: T.Tensor(indices_shape, indices_dtype),
        Lse: T.Tensor(lse_shape, accum_dtype),
        Delta: T.Tensor(delta_shape, accum_dtype),
        dQ: T.Tensor(q_shape, dtype),
        dKV: T.Tensor(k_shape, accum_dtype),
    ):
        with T.Kernel(S, B, kv_group, threads=threads) as (s_i, by, bz):
            Q_shared = T.alloc_shared([padded_H, D], dtype)
@@ -165,17 +162,19 @@ def bwd(
            max_kv_i = s_i

            T.copy(Q[by, s_i, bz * padded_H : (bz + 1) * padded_H, :D], Q_shared)
            T.copy(Q[by, s_i, bz * padded_H : (bz + 1) * padded_H, D:], Q_tail_shared)
            T.copy(dO[by, s_i, bz * padded_H : (bz + 1) * padded_H, :D], dO_shared)

            T.clear(acc_dq)
            T.clear(acc_dq_tail)
            T.annotate_layout(
                {
                    dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared),
                    dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared),
                }
            )

            # Process each block of indices
            for i_i in T.Pipelined(NS, num_stages=num_stages):
@@ -191,62 +190,31 @@ def bwd(
                for bi_i, d_i in T.Parallel(BS, D):
                    KV_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BS + bi_i], bz, d_i]
                T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)

                for bi_i, d_i in T.Parallel(BS, D_tail):
                    KV_tail_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BS + bi_i], bz, D + d_i]
                T.gemm(Q_tail_shared, KV_tail_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)

                for h_i, bi_i in T.Parallel(padded_H, BS):
                    acc_p[h_i, bi_i] = T.exp2(acc_p[h_i, bi_i] * sm_scale_mul_reciprocal_log2 - Lse[by, s_i, bz * padded_H + h_i])
                T.copy(acc_p, P_shared_cast)

                T.gemm(dO_shared, KV_shared, acc_dp, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
                for h_i, bi_i in T.Parallel(padded_H, BS):
                    acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[by, s_i, bz * padded_H + h_i]) * sm_scale
                T.copy(acc_dp, dP_shared_cast)

                T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol)
                T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol)

                T.gemm(dP_shared_cast, Q_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
                T.gemm(P_shared_cast, dO_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol)

                T.clear(acc_dkv_tail)
                T.gemm(dP_shared_cast, Q_tail_shared, acc_dkv_tail, transpose_A=True, policy=T.GemmWarpPolicy.FullCol)

                for s in range(split_store):
                    for bi_i, d_i in T.Parallel(BS, D):
@@ -255,41 +223,32 @@ def bwd(
                    for bi_i, d_i in T.Parallel(BS, D_tail):
                        if bi_i < BS // split_store:
                            acc_dkv_tail_shared[bi_i, d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), d_i]

                    for bi_i, d_i in T.Parallel(BS // split_store, D // 4):
                        T.atomic_addx4(
                            dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, d_i * 4],
                            acc_dkv_shared[bi_i, d_i * 4],
                        )
                    # Atomically update dKV, dKV_tail tensors
                    for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4):
                        T.atomic_addx4(
                            dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, D + d_i * 4],
                            acc_dkv_tail_shared[bi_i, d_i * 4],
                        )

            # Store the accumulated dQ
            T.copy(acc_dq, dQ_shared)
            T.copy(acc_dq_tail, dQ_tail_shared)
            T.copy(dQ_shared, dQ[by, s_i, bz * padded_H : (bz + 1) * padded_H, :D])
            T.copy(dQ_tail_shared, dQ[by, s_i, bz * padded_H : (bz + 1) * padded_H, D:])

    return sparse_mla_bwd_kernel


def sparse_mla_bwd(q, kv, o, do, indices, lse, sm_scale=None, is_casual=True, return_kernel=False, delta=None):
    assert q.is_contiguous()
    assert kv.is_contiguous()
    assert indices.is_contiguous()

@@ -322,6 +281,7 @@ def sparse_mla_bwd(q,
def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, sm_scale=None, is_casual=True):
    from sparse_mla_fwd import ref_sparse_mla_fwd_interface

    q = q.detach().clone()
    kv = kv.detach().clone()
    q.requires_grad = True

@@ -331,30 +291,22 @@ def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, sm_scale=None, is_c
    return q.grad, kv.grad


def test_sparse_mla_bwd(B=1, S=4096, SKV=8192, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16, check_correctness=True):
    # Prepare data
    q = torch.randn((B, S, H, DQKV), dtype=dtype, device="cuda").requires_grad_(True)
    kv = torch.randn((B, SKV, HKV, DQKV), dtype=dtype, device="cuda").requires_grad_(True)
    do = torch.randn((B, S, H, DV), dtype=dtype, device="cuda")
    indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda")
    for b in range(B):
        for t in range(S):
            for h in range(HKV):
                i_i = torch.randperm(max(1, t))[:topk]
                indices[b, t, h, : len(i_i)] = i_i

    # Forward
    from sparse_mla_fwd import sparse_mla_fwd_interface

    tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices)
    tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse)

@@ -365,13 +317,15 @@ def test_sparse_mla_bwd(B=1,
    assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv")
    print("assert_tensors_similar passed")

    per_token_flop = 2 * sum(
        [
            H * DV * topk,
            H * DQKV * topk,
            H * DQKV * topk,
            H * DQKV * topk,
            H * DV * topk,
        ]
    )

    from tilelang.profiler import do_bench

    def fn():

@@ -379,20 +333,9 @@ def test_sparse_mla_bwd(B=1,
    ms = do_bench(fn, rep=100, warmup=250)
    print(f"Average time: {ms:.3f} ms")
    print(f"bwd io bandwidth = ", (B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12)
    print(f"bwd tflops = ", per_token_flop * S / (ms * 1e-3) / 1e12)


if __name__ == "__main__":
    test_sparse_mla_bwd(B=1, S=4096, SKV=8192, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16, check_correctness=True)
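As a sanity check of the benchmark arithmetic in test_sparse_mla_bwd above, the per_token_flop sum can be evaluated by hand for the default shapes (H=64, DQKV=576, DV=512, topk=2048, S=4096). The sketch below is illustrative only and is not part of the commit:

# Worked evaluation of per_token_flop with the defaults of test_sparse_mla_bwd.
H, DQKV, DV, topk, S = 64, 576, 512, 2048, 4096
per_token_flop = 2 * sum([
    H * DV * topk,
    H * DQKV * topk,
    H * DQKV * topk,
    H * DQKV * topk,
    H * DV * topk,
])
assert per_token_flop == 2 * 64 * 2048 * (2 * 512 + 3 * 576) == 721_420_288  # ~0.72 GFLOP per query token

# The "bwd tflops" print divides per_token_flop * S by the measured time in seconds and by 1e12;
# with S=4096 the numerator is 721_420_288 * 4096 = 2_954_937_499_648 FLOPs (~2.95 TFLOP) per call.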
@@ -25,15 +25,12 @@ def sparse_mla_fwd(
    num_stages=2,
    threads=256,
):
    assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}"
    assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}"
    assert is_causal == True, "non-casual is not supported"
    assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded"
    if sm_scale is None:
        sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504  # log2(e)
    else:
        sm_scale = sm_scale * 1.44269504  # log2(e)
@@ -55,9 +52,9 @@ def sparse_mla_fwd(
    H = head_kv
    padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
    if padded_H != H:
        assert kv_group == 1, (
            "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)"
        )
    BI = block_I
    NI = tilelang.cdiv(topk, block_I)
    D = dim
@@ -73,18 +70,17 @@ def sparse_mla_fwd(
    @T.prim_func
    def main(
        Q: T.Tensor(q_shape, dtype),  # type: ignore
        KV: T.Tensor(kv_shape, dtype),  # type: ignore
        Indices: T.Tensor(indices_shape, indices_dtype),  # type: ignore
        Output: T.Tensor(o_shape, dtype),  # type: ignore
        Lse: T.Tensor(lse_shape, accum_dtype),  # type: ignore
    ):
        with T.Kernel(seq_len * REPLICATE_H, batch, kv_group, threads=threads) as (
            bx,
            by,
            bz,
        ):
            Q_shared = T.alloc_shared([H_per_block, D], dtype)
            Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
            KV_shared = T.alloc_shared([BI, D], dtype)
@@ -118,16 +114,13 @@ def sparse_mla_fwd(
            T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)

            for i_i in T.Pipelined(NI, num_stages=num_stages):
                for bi_i in T.Parallel(BI):
                    mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i

                for bi_i, d_i in T.Parallel(BI, D):
                    KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i]
                for bi_i, d_i in T.Parallel(BI, D_tail):
                    K_tail_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i]

                for h_i, bi_i in T.Parallel(H_per_block, BI):
                    acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype))
@@ -176,15 +169,7 @@ def sparse_mla_fwd(
    return main


def sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, return_p_sum: bool = False, d_v=512, block_I=64, num_stages=2, threads=256):
    is_casual = True
    assert return_p_sum == False, "This kernel file is for fwd only"
    assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
@@ -201,16 +186,8 @@ def sparse_mla_fwd_interface(q,
    assert indices.shape == (batch, seq_len, kv_group, topk)

    kernel = sparse_mla_fwd(
        heads, dim, tail_dim, topk, kv_group, sm_scale, is_casual, block_I=block_I, num_stages=num_stages, threads=threads
    )
    out, lse = kernel(q, kv, indices)
    return out, lse
@@ -230,14 +207,14 @@ def ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, is_casual=True):
    b, _, _, dim_v = v.shape
    g_index = g
    h_index = h // g
    compressed_casual_mask = torch.arange(0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange(
        1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda"
    ).view(1, -1)
    mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1)
    mask = mask[..., :-1]
    mask = mask & compressed_casual_mask.view(1, 1, sq, sk)
    mask[:, :, : 1 - 1, 0] = True
    mask = mask.view(b, g_index, 1, sq, sk)
    q = q.view(b, sq, g, -1, dim_q)
@@ -252,19 +229,21 @@ def ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, is_casual=True):
    return o.to(torch.bfloat16)


def test_sparse_mla_fwd(
    B=1,
    S=4096,
    SKV=8192,
    H=128,
    HKV=1,
    DQK=576,
    DV=512,
    topk=2048,
    dtype=torch.bfloat16,
    check_correctness=True,
    block_I=64,
    num_stages=2,
    threads=256,
):
    torch.random.manual_seed(0)
    q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True)
    kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True)
@@ -274,10 +253,9 @@ def test_sparse_mla_fwd(B=1,
        for t in range(S):
            for h in range(HKV):
                i_i = torch.randperm(max(1, t))[:topk]
                indices[b, t, h, : len(i_i)] = i_i

    tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads)

    if check_correctness:
        # otherwise may cause out of memory

@@ -286,8 +264,7 @@ def test_sparse_mla_fwd(B=1,
    print("assert_tensors_similar passed")

    def fn():
        return sparse_mla_fwd_interface(q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads)

    from tilelang.profiler import do_bench
@@ -315,4 +292,5 @@ if __name__ == "__main__":
        check_correctness=True,
        block_I=64,
        num_stages=2,
        threads=256,
    )
@@ -9,10 +9,16 @@ import argparse
@tilelang.jit(
    out_idx=[-2, -1],
    compile_flags=[
        "-O3",
        "-Wno-deprecated-declarations",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_HALF2_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--ptxas-options=-v,--register-usage-level=10",
        "-DNDEBUG",
    ],
)
def sparse_mla_fwd(
@@ -32,14 +38,12 @@ def sparse_mla_fwd(
    num_stages=0,
    threads=384,
):
    assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}"
    assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}"
    assert is_causal == True, "non-casual is not supported"
    assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded"
    if sm_scale is None:
        sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504  # log2(e)
    else:
        sm_scale = sm_scale * 1.44269504  # log2(e)
@@ -57,15 +61,17 @@ def sparse_mla_fwd(
    H = head_kv
    padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
    if padded_H != H:
        assert kv_group == 1, (
            "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)"
        )
    BI = block_I
    NI = tilelang.cdiv(topk, block_I)
    assert NI % 2 == 0, "NI should be a multiple of 2"
    D = dim
    D_tail = tail_dim
    KV_stride = kv_stride
    if head_kv > 64:
        assert head_kv % 64 == 0, "head_kv should be a multiple of 64"
        REPLICATE_H = head_kv // 64
    else:
        REPLICATE_H = 1
@@ -74,18 +80,14 @@ def sparse_mla_fwd(
    @T.prim_func
    def main(
        Q: T.Tensor(q_shape, dtype),  # type: ignore
        KV: T.Tensor(kv_shape, dtype),  # type: ignore
        Indices: T.Tensor(indices_shape, indices_dtype),  # type: ignore
        q_start_index_s: T.Tensor(1, indices_dtype),
        Output: T.Tensor(o_shape, dtype),  # type: ignore
        Lse: T.Tensor(lse_shape, accum_dtype),  # type: ignore
    ):
        with T.Kernel((seq_len - kv_stride + 1 if CP0 else seq_len) * REPLICATE_H, batch, kv_group, threads=threads) as (bx, by, bz):
            Q_shared_l = T.alloc_shared([H_per_block, D // 2], dtype)
            Q_shared_r = T.alloc_shared([H_per_block, D // 2], dtype)
            Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)

@@ -122,8 +124,7 @@ def sparse_mla_fwd(
            bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256)

            b_i, g_i = by, bz
            s_i = (bx + (KV_stride - 1 if CP0 else 0)) if REPLICATE_H == 1 else (bx // REPLICATE_H + (KV_stride - 1 if CP0 else 0))
            q_i = q_start_index_s[0] + s_i
            max_kv_i = (q_i + 1 - KV_stride) // KV_stride
@@ -132,26 +133,24 @@ def sparse_mla_fwd(
            tx = T.get_thread_binding()

            T.copy(Q[b_i, s_i, H0:H1, 0 : D // 2], Q_shared_l)
            T.copy(Q[b_i, s_i, H0:H1, D // 2 : D], Q_shared_r)
            T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)
            T.barrier_arrive(bar_q)

            if tx < 128:
                T.set_max_nreg(240, 1)
                T.fill(sumexp, 0)
                T.fill(m_i, -(2**30))  # avoid -inf - inf to cause nan
                T.fill(acc_o_l, 0)
                T.barrier_wait(bar_q, 0)
                for i_i in T.serial(T.ceildiv(NI, 2)):
                    # Buffer 0
                    T.barrier_wait(bar_k_0_ready[0], (i_i & 1))
                    for h_i, bi_i in T.Parallel(H_per_block, BI):
                        acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, -T.infinity(acc_s.dtype))
                    T.gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1)
                    T.gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1)
                    T.gemm(Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True, wg_wait=-1)

@@ -187,8 +186,7 @@ def sparse_mla_fwd(
                    T.barrier_wait(bar_k_1_ready[0], (i_i & 1))
                    for h_i, bi_i in T.Parallel(H_per_block, BI):
                        acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, -T.infinity(acc_s.dtype))
                    T.gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1)
                    T.gemm(Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1)
                    T.gemm(Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True, wg_wait=-1)

@@ -227,7 +225,7 @@ def sparse_mla_fwd(
                for h_i in T.Parallel(H_per_block):
                    sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale
                T.copy(acc_o_l, O_shared_l)
                T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0 : D // 2])

            elif tx >= 128 and tx < 256:
                T.set_max_nreg(168, 1)

@@ -257,7 +255,7 @@ def sparse_mla_fwd(
                    acc_o_r[h_i, d_i] /= sum_exp_shared[h_i]
                T.copy(acc_o_r, O_shared_r)
                T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2 : D])

            elif tx >= 256:
                # producer
                T.set_max_nreg(80, 0)
@@ -265,70 +263,58 @@ def sparse_mla_fwd(
                    # Buffer 0
                    T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1))
                    for r in T.serial(4):
                        indices_local[0] = Indices[b_i, s_i, g_i, (i_i * 2) * BI + r * 16 + (tx - 256) // 8]
                        is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i
                        if is_kv_valid[r * 16 + (tx - 256) // 8]:
                            with T.attr("default", "async_scope", 1):
                                for u in T.serial(4):
                                    for v in T.vectorized(8):
                                        KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[
                                            b_i, indices_local[0], g_i, 64 * u + (tx - 256) % 8 * 8 + v
                                        ]
                                        KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[
                                            b_i, indices_local[0], g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v
                                        ]
                            with T.attr("default", "async_scope", 1):
                                for v in T.vectorized(8):
                                    K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[
                                        b_i, indices_local[0], g_i, D + (tx - 256) % 8 * 8 + v
                                    ]
                    T.cp_async_barrier_noinc(bar_k_0_ready[0])

                    # Buffer 1
                    T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1))
                    for r in T.serial(4):
                        indices_local[0] = Indices[b_i, s_i, g_i, (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8]
                        is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i
                        if is_kv_valid[r * 16 + (tx - 256) // 8]:
                            with T.attr("default", "async_scope", 1):
                                for u in T.serial(4):
                                    for v in T.vectorized(8):
                                        KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[
                                            b_i, indices_local[0], g_i, 64 * u + (tx - 256) % 8 * 8 + v
                                        ]
                                        KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[
                                            b_i, indices_local[0], g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v
                                        ]
                            with T.attr("default", "async_scope", 1):
                                for v in T.vectorized(8):
                                    K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[
                                        b_i, indices_local[0], g_i, D + (tx - 256) % 8 * 8 + v
                                    ]
                    T.cp_async_barrier_noinc(bar_k_1_ready[0])

    return main


def sparse_mla_fwd_interface(
    q, kv, indices, q_start_index_s, kv_stride, sm_scale=None, is_casual=True, return_kernel=False, print_kernel=False
):
    assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
    batch, seq_len, heads, dim_plus_tail_dim = q.shape
    _, seq_len_kv, kv_group, _ = kv.shape
    assert dim_plus_tail_dim == 576, "you should assign dim otherwise"
    dim = 512
    assert kv.shape[-1] == dim_plus_tail_dim
@@ -338,29 +324,23 @@ def sparse_mla_fwd_interface(q,
    assert indices.shape == (batch, seq_len, kv_group, topk)

    if q_start_index_s != 0:
        assert q_start_index_s > kv_stride, (
            "If it is because each cp has too short length, you should fix the logic involving CP0 (cp_rank == 0), to make sure q with pos < KV_Stride - 1 is masked (or you may just ignore how this is handled if nan in these q's Out would not effect others, which is reported to be likely to happen by wangding)"
        )
    CP0 = q_start_index_s == 0

    kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, kv_group, sm_scale, is_casual, CP0)
    if print_kernel:
        print(kernel.get_kernel_source())
    out, lse = kernel(q, kv, indices, torch.tensor([q_start_index_s], dtype=torch.int32, device="cuda"))
    if return_kernel:
        return kernel
    if q_start_index_s == 0 and kv_stride > 1:
        out[:, : kv_stride - 1, :, :] = 0
    return out, lse


def ref_sparse_mla_fwd_interface(q, kv, indices, q_start_index_s, kv_stride=4, sm_scale=None, is_casual=True):
    q = q.float()
    kv = kv.float()
    indices = indices.transpose(1, 2)

@@ -369,7 +349,7 @@ def ref_sparse_mla_fwd_interface(q,
    if q_start_index_s is None:
        q_start_index_s = sk * kv_stride - sq
    assert kv.shape[-1] == 576, "you should assign dim otherwise"
    dim = 512
    k = kv
    v = kv[..., :dim]

@@ -378,15 +358,14 @@ def ref_sparse_mla_fwd_interface(q,
    num_kv_per_index = 1
    g_index = g
    h_index = h // g
    compressed_casual_mask = torch.arange(q_start_index_s, sq + q_start_index_s, dtype=torch.int32, device="cuda").view(
        -1, 1
    ) >= torch.arange(kv_stride - 1, sk * kv_stride, kv_stride, dtype=torch.int32, device="cuda").view(1, -1)
    mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1)
    mask = mask[..., :-1]
    mask = mask & compressed_casual_mask.view(1, 1, sq, sk)
    mask[:, :, : kv_stride - 1, 0] = True
    mask = mask.view(b, g_index, 1, sq, sk)
    q = q.view(b, sq, g, -1, dim_q)
@@ -401,41 +380,32 @@ def ref_sparse_mla_fwd_interface(q,
    return o.to(torch.bfloat16)


def test_sparse_mla_fwd_pipelined(
    B=1, S=4096, SKV=8192, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16, q_start_s_index=1024, check_correctness=True
):
    KV_stride = 1
    torch.random.manual_seed(0)
    q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) / 10
    kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) / 10
    q_start_s_index_t = torch.tensor([q_start_s_index], dtype=torch.int32, device="cuda")
    q.clamp_(-10, 10)
    kv.clamp_(-10, 10)
    indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda")
    for b in range(B):
        for t in range(S):
            for h in range(HKV):
                i_i = torch.randperm(min(max(1, ((t + q_start_s_index) // KV_stride)), SKV))[:topk]
                indices[b, t, h, : len(i_i)] = i_i

    kernel = sparse_mla_fwd_interface(q, kv, indices, q_start_s_index, KV_stride, return_kernel=True, print_kernel=True)

    def fn():
        out, lse = kernel(q, kv, indices, q_start_s_index_t)
        if q_start_s_index == 0 and KV_stride > 1:
            out[:, : KV_stride - 1, :, :] = 0
        return out, lse

    tl_out, tl_lse = fn()

@@ -446,14 +416,15 @@ def test_sparse_mla_fwd_pipelined(B=1,
    torch.testing.assert_close(tl_out, ref_out, rtol=1e-3, atol=1e-3)

    from tilelang.profiler import do_bench

    ms = do_bench(
        fn,
        rep=10,
        warmup=10,
    )
    print(f"Average time: {ms:.3f} ms")
    print(f"fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12)
    print(f"fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12)


if __name__ == "__main__":

@@ -464,5 +435,4 @@ if __name__ == "__main__":
        B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 8192, 128, 1, 576, 512, 2048, torch.bfloat16
    else:
        B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 4096, 8192, 128, 1, 576, 512, 2048, torch.bfloat16
    test_sparse_mla_fwd_pipelined(B, S, SKV, H, HKV, DQK, DV, topk, dtype, check_correctness=args.test_correctness)
@@ -21,23 +21,20 @@ def test_example_fp8_lighting_indexer():
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_fwd():
    # small shapes for testing
    sparse_mla_fwd.test_sparse_mla_fwd(S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_fwd_pipelined():
    # small shapes for testing
    sparse_mla_fwd_pipelined.test_sparse_mla_fwd_pipelined(S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_bwd():
    sparse_mla_bwd.test_sparse_mla_bwd(S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)


if __name__ == "__main__":
    ...
@@ -127,9 +127,9 @@ def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"):
                l_num_input = s_num_input[r_idx]
                for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)):
                    if s * BLOCK_SIZE + tx < l_num_input:
                        l_bin_id32 = T.Cast(
                            "int32", ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF)
                        )
                        T.atomic_add(s_histogram[l_bin_id32], 1)
                T.sync_threads()
                # cumsum
@@ -156,23 +156,20 @@ def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"):
                for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)):
                    T.sync_threads()
                    if s * BLOCK_SIZE + tx < l_num_input:
                        l_bin_id32 = T.Cast(
                            "int32", ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF)
                        )
                        if l_bin_id32 > l_threshold_bin_id:
                            pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos
                            index[bx, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx]
                        elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
                            if round == 3:
                                l_out_pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos
                                if l_out_pos < topk:
                                    index[bx, l_out_pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx]
                            else:
                                pos = T.atomic_add(s_num_input[r_idx ^ 1], 1, return_prev=True)
                                s_input_idx[r_idx ^ 1, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx]

    return tl_topk_kernel

@@ -186,7 +183,6 @@ def tl_topk(input, starts, ends, topk):
def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048):
    batch = 64
    seq_len = 32 * 1024
    topk = 2048

@@ -212,8 +208,7 @@ def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048):
    set_ref = set(ref_np)
    set_trt = set(trt_np)
    intersection = set_ref & set_trt
    print("selected/all:", len(intersection), "/", len(set_ref), "=", len(intersection) / len(set_ref))

    # Performance test with CUDA events
    ...
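For orientation, the two tl_topk_impl hunks above implement a radix top-k select: each of four rounds takes an 8-bit digit of the uint32-converted key, builds a 256-bin histogram, keeps everything above a per-round threshold bin, and pushes the ties into the next round. The host-side NumPy sketch below mirrors only that digit/histogram logic for a single row of ready-made uint32 keys; the kernel additionally converts floats to order-preserving keys (convert_to_uint32) and assigns output positions with atomics, so treat this as an illustration rather than a drop-in reference.

import numpy as np

def radix_topk_indices(keys: np.ndarray, topk: int) -> np.ndarray:
    # Select indices of the `topk` largest uint32 keys by scanning 8-bit digits
    # MSB-first, mirroring the `(key >> (24 - round * 8)) & 0xFF` digits above.
    candidates = np.arange(len(keys))
    selected = []
    remaining = topk
    for rnd in range(4):  # digits from bits 31..24 down to 7..0
        digits = ((keys[candidates] >> (24 - rnd * 8)) & 0xFF).astype(np.int64)
        hist = np.bincount(digits, minlength=256)
        running = 0  # candidates whose digit is strictly above the threshold bin
        for b in range(255, -1, -1):
            if running + hist[b] >= remaining:
                threshold = b
                break
            running += hist[b]
        # Digits above the threshold are certainly in the top-k; ties move to the next round.
        selected.extend(candidates[digits > threshold].tolist())
        remaining -= running
        candidates = candidates[digits == threshold]
    # After the last round all remaining candidates share the same key; take what is left.
    selected.extend(candidates[:remaining].tolist())
    return np.asarray(selected)

# Example: agrees with a sort-based top-k on random keys.
keys = np.random.randint(0, 2**32, size=10_000, dtype=np.uint64).astype(np.uint32)
idx = radix_topk_indices(keys, topk=128)
assert set(keys[idx].tolist()) == set(np.sort(keys)[-128:].tolist())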
@@ -23,8 +23,7 @@ def _is_equal(a, b):
    if isinstance(a, torch.Tensor):
        return a is b
    # Whitelist of types that are safe to compare by value for caching.
    if isinstance(a, (int, float, str, bool, type(None))) and isinstance(b, (int, float, str, bool, type(None))):
        return a == b
    # For other types, we cannot guarantee a cheap and safe comparison, so we fail the cache check.
    return False

@@ -58,9 +57,11 @@ def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]
        if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
            # For Tensors, check for object identity. For other types, check for equality.
            # Python caches small integers, so `is` works for them but not for large integers like 4096.
            if (
                all(_is_equal(a, b) for a, b in zip(args, last_args))
                and set(kwargs.keys()) == set(last_kwargs.keys())
                and all(_is_equal(v, last_kwargs[k]) for k, v in kwargs.items())
            ):
                return last_result

        result = fn(*args, **kwargs)
@@ -79,73 +80,68 @@ def cal_seq_idx_from_cu_seqlens(cu_seqlens: torch.LongTensor, seq_len: int):
@tensor_cache
def cal_seq_idx_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, seq_len: int) -> torch.IntTensor:
    seq_idx_for_q = torch.full((seq_len,), len(cu_seqlens_qs), dtype=torch.int32, device=cu_seqlens_qs.device)
    for i in range(len(cu_seqlens_qs)):
        seq_idx_for_q[cu_seqlens_qs[i] : cu_seqlens_qe[i]] = i
    return seq_idx_for_q


@tensor_cache
def cal_cu_seqlen_ks_for_q(
    cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, cu_seqlens_ks: torch.LongTensor, seq_len: int
) -> torch.IntTensor:
    cu_seqlen_ks_for_each_q = torch.gather(
        input=torch.cat([cu_seqlens_ks, torch.full((1,), torch.iinfo(torch.int32).max, dtype=torch.int32, device=cu_seqlens_qs.device)]),
        dim=0,
        index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long(),
    )
    return cu_seqlen_ks_for_each_q.int()


@tensor_cache
def cal_cu_seqlen_ke_for_q(
    cu_seqlens_qs: torch.LongTensor,
    cu_seqlens_qe: torch.LongTensor,
    cu_seqlens_ks: torch.LongTensor,
    cu_seqlens_ke: torch.LongTensor,
    q_start_idxs: torch.LongTensor,
    seq_len: int,
    kv_stride: int,
) -> torch.IntTensor:
    cu_seqlen_ke_for_each_q = torch.gather(
        input=torch.cat([cu_seqlens_ke, torch.zeros(1, dtype=torch.int32, device=cu_seqlens_qs.device)]),
        dim=0,
        index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long(),
    )
    casual_cu_seqlen_ke_for_each_q = torch.zeros((seq_len,), dtype=torch.int32, device=cu_seqlens_qs.device)
    for i in range(len(cu_seqlens_qs)):
        casual_cu_seqlen_ke_for_each_q[cu_seqlens_qs[i] : cu_seqlens_qe[i]] = (
            torch.arange(
                q_start_idxs[i], q_start_idxs[i] + cu_seqlens_qe[i] - cu_seqlens_qs[i], dtype=torch.int32, device=cu_seqlens_qs.device
            )
            + 1
        ) // kv_stride + cu_seqlens_ks[i]
    cu_seqlen_ke_for_each_q = torch.minimum(casual_cu_seqlen_ke_for_each_q, cu_seqlen_ke_for_each_q)
    return cu_seqlen_ke_for_each_q.int()
@tensor_cache @tensor_cache
def cal_ks_ke_from_cu_seqlen_qk(cu_seqlens_q: torch.LongTensor, def cal_ks_ke_from_cu_seqlen_qk(
cu_seqlens_k: torch.LongTensor = None, cu_seqlens_q: torch.LongTensor,
offs_q: torch.LongTensor = None, cu_seqlens_k: torch.LongTensor = None,
*, offs_q: torch.LongTensor = None,
seq_len: int, *,
kv_stride: int = 1, seq_len: int,
cp_rank: int = 0, kv_stride: int = 1,
cp_size: int = 1, cp_rank: int = 0,
balanced_cp=False): cp_size: int = 1,
''' balanced_cp=False,
):
"""
seq_len: seq len per cp rank seq_len: seq len per cp rank
balanced cp slice assignment: 0 1 2 3 3 2 1 0 balanced cp slice assignment: 0 1 2 3 3 2 1 0
''' """
n_seq = len(cu_seqlens_q) - 1 n_seq = len(cu_seqlens_q) - 1
assert n_seq > 0 assert n_seq > 0
assert cu_seqlens_q.shape == (n_seq + 1,) assert cu_seqlens_q.shape == (n_seq + 1,)
...@@ -170,10 +166,12 @@ def cal_ks_ke_from_cu_seqlen_qk(cu_seqlens_q: torch.LongTensor, ...@@ -170,10 +166,12 @@ def cal_ks_ke_from_cu_seqlen_qk(cu_seqlens_q: torch.LongTensor,
def f(x: torch.Tensor): def f(x: torch.Tensor):
chunks = x.chunk(cp_size * 2) chunks = x.chunk(cp_size * 2)
return torch.cat([ return torch.cat(
chunks[cp_rank], [
chunks[cp_size - cp_rank - 1], chunks[cp_rank],
]) chunks[cp_size - cp_rank - 1],
]
)
ks = f(ks) ks = f(ks)
ke = f(ke) ke = f(ke)
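# Illustration of the balanced-cp zigzag described in the docstring ("0 1 2 3 3 2 1 0"
# for cp_size = 4): the full sequence is cut into 2 * cp_size equal chunks and rank r
# keeps chunk r together with its mirror chunk 2 * cp_size - 1 - r, so a cheap early
# causal chunk is always paired with an expensive late one. Toy sizes, illustration only:
import torch

cp_size, chunk = 4, 3
x = torch.arange(2 * cp_size * chunk)
chunks = x.chunk(2 * cp_size)
for r in range(cp_size):
    print(r, torch.cat([chunks[r], chunks[2 * cp_size - 1 - r]]).tolist())
# 0 [0, 1, 2, 21, 22, 23]
# 1 [3, 4, 5, 18, 19, 20]
# 2 [6, 7, 8, 15, 16, 17]
# 3 [9, 10, 11, 12, 13, 14]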
...@@ -189,8 +187,7 @@ def ceil_to_ue8m0(x: torch.Tensor): ...@@ -189,8 +187,7 @@ def ceil_to_ue8m0(x: torch.Tensor):
return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int], def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int], use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
sf = x_amax / 448.0 sf = x_amax / 448.0
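# Hedged sketch of where this is heading (the remainder of the function is not shown
# in this hunk, so the quantize/return steps below are assumptions, not the committed
# code). 448 is the largest normal float8_e4m3fn value, and the ue8m0 option rounds
# each scale up to a power of two; requires a PyTorch build with float8 support.
import torch

def cast_to_fp8_rowwise(x: torch.Tensor, use_ue8m0: bool = True):
    amax = x.abs().float().amax(dim=-1, keepdim=True).clamp(1e-4)
    sf = amax / 448.0
    if use_ue8m0:
        sf = torch.pow(2.0, torch.ceil(torch.log2(sf)))  # same rounding as ceil_to_ue8m0
    return (x / sf).to(torch.float8_e4m3fn), sf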
...@@ -239,14 +236,18 @@ def generate_random_cu_seqlens(per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1, ...@@ -239,14 +236,18 @@ def generate_random_cu_seqlens(per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1,
total_seqlen - (cp_rank + 1) * per_chunk_seqlen, total_seqlen - (cp_rank + 1) * per_chunk_seqlen,
total_seqlen - cp_rank * per_chunk_seqlen, total_seqlen - cp_rank * per_chunk_seqlen,
) )
ks = torch.cat([ ks = torch.cat(
cu_seqlens_ks_for_each_q[slice_short], [
cu_seqlens_ks_for_each_q[slice_long], cu_seqlens_ks_for_each_q[slice_short],
]) cu_seqlens_ks_for_each_q[slice_long],
ke = torch.cat([ ]
cu_seqlens_ke_for_each_q[slice_short], )
cu_seqlens_ke_for_each_q[slice_long], ke = torch.cat(
]) [
cu_seqlens_ke_for_each_q[slice_short],
cu_seqlens_ke_for_each_q[slice_long],
]
)
assert len(ks) == len(ke) == per_cp_seqlen assert len(ks) == len(ke) == per_cp_seqlen
return ks, ke return ks, ke
...@@ -302,11 +303,9 @@ def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True): ...@@ -302,11 +303,9 @@ def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True):
raise_assert: Whether to raise assertion error on failure raise_assert: Whether to raise assertion error on failure
""" """
sim = calculate_tensor_similarity(x, y, name) sim = calculate_tensor_similarity(x, y, name)
diff = 1. - sim diff = 1.0 - sim
if not (0 <= diff <= eps): if not (0 <= diff <= eps):
print( print(f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m")
f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m"
)
if raise_assert: if raise_assert:
assert False # noqa: B011 assert False # noqa: B011
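# The metric above is sim = 2 * sum(x * y) / (sum(x^2) + sum(y^2)), so diff = 1 - sim
# equals sum((x - y)^2) / (sum(x^2) + sum(y^2)): zero only for exact equality, growing
# with both angular and magnitude error. Quick check with made-up tensors:
import torch

x = torch.randn(1024, dtype=torch.float64)
y = x + 1e-4 * torch.randn_like(x)
sim = 2 * (x * y).sum() / (x * x + y * y).sum()
print(f"{(1.0 - sim).item():.1e}")  # about 5e-9 for this perturbation, so eps = 1e-8 passes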
...@@ -316,11 +315,8 @@ if __name__ == "__main__": ...@@ -316,11 +315,8 @@ if __name__ == "__main__":
cu_seqlens = torch.randint(128, 4096, (1000,), dtype=torch.int32, device="cuda") cu_seqlens = torch.randint(128, 4096, (1000,), dtype=torch.int32, device="cuda")
last_idx = torch.where(cu_seqlens.cumsum(dim=0) >= seq_len)[0][0] last_idx = torch.where(cu_seqlens.cumsum(dim=0) >= seq_len)[0][0]
cu_seqlens_cumsum = cu_seqlens[:last_idx].cumsum(dim=0) cu_seqlens_cumsum = cu_seqlens[:last_idx].cumsum(dim=0)
cu_seqlens_qs = torch.cat( cu_seqlens_qs = torch.cat([torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device), cu_seqlens_cumsum])
[torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device), cu_seqlens_cumsum]) cu_seqlens_qe = torch.cat([cu_seqlens_cumsum, torch.ones(1, dtype=torch.int32, device=cu_seqlens.device) * seq_len])
cu_seqlens_qe = torch.cat(
[cu_seqlens_cumsum,
torch.ones(1, dtype=torch.int32, device=cu_seqlens.device) * seq_len])
from tilelang.profiler import do_bench from tilelang.profiler import do_bench
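# The construction above yields contiguous, exhaustive intervals: qe[i] == qs[i + 1]
# and the final segment absorbs whatever is left of seq_len. Small self-check with
# made-up lengths:
import torch

seq_len_ex = 10
lens = torch.tensor([3, 4, 6])
cs = lens[: torch.where(lens.cumsum(0) >= seq_len_ex)[0][0]].cumsum(0)  # tensor([3, 7])
qs = torch.cat([torch.zeros(1, dtype=cs.dtype), cs])                    # tensor([0, 3, 7])
qe = torch.cat([cs, torch.tensor([seq_len_ex])])                        # tensor([3, 7, 10])
assert torch.equal(qs[1:], qe[:-1]) and qe[-1] == seq_len_ex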
......
...@@ -39,12 +39,10 @@ def torch_convert_bit_twiddling(tensor): ...@@ -39,12 +39,10 @@ def torch_convert_bit_twiddling(tensor):
res0 = val_concat_expanded & mask res0 = val_concat_expanded & mask
res1 = (val_concat_expanded << 3) & mask res1 = (val_concat_expanded << 3) & mask
res2 = (val_concat_expanded << 6) & mask res2 = (val_concat_expanded << 6) & mask
res3 = ((val_concat_expanded << 1) & mask1) | ((val_concat_expanded >> 3) & mask2) | ( res3 = ((val_concat_expanded << 1) & mask1) | ((val_concat_expanded >> 3) & mask2) | ((val_concat_expanded >> 7) & mask3)
(val_concat_expanded >> 7) & mask3)
# Select the correct result based on position # Select the correct result based on position
bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1, bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1, torch.where(pos == 2, res2, res3)))
torch.where(pos == 2, res2, res3)))
# Convert to uint16 for .view(torch.bfloat16) # Convert to uint16 for .view(torch.bfloat16)
bf16_uint16 = (bf16 & 0xFFFF).to(torch.uint16) bf16_uint16 = (bf16 & 0xFFFF).to(torch.uint16)
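# Two standalone illustrations of the building blocks used above (toy values; the
# nibble order is an assumption, not the kernel's exact bit layout, and torch.uint16
# needs a recent PyTorch build, which this file already assumes):
import torch

packed = torch.tensor([0xAB, 0x01], dtype=torch.uint8)
lo, hi = packed & 0xF, (packed >> 4) & 0xF       # two 4-bit codes per byte
print(lo.tolist(), hi.tolist())                  # [11, 1] [10, 0]

bits = torch.tensor([0x3F80, 0xBF80], dtype=torch.uint16)   # 1.0 and -1.0 bit patterns
print(bits.view(torch.bfloat16))                 # tensor([ 1., -1.], dtype=torch.bfloat16)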
...@@ -110,7 +108,7 @@ def print_bit(name, val): ...@@ -110,7 +108,7 @@ def print_bit(name, val):
val (torch.Tensor): A scalar PyTorch tensor (numeric) whose 32-bit binary representation will be shown. val (torch.Tensor): A scalar PyTorch tensor (numeric) whose 32-bit binary representation will be shown.
""" """
val_cpu = val.cpu().item() val_cpu = val.cpu().item()
binary_repr = f'{val_cpu:032b}' binary_repr = f"{val_cpu:032b}"
print(name, binary_repr) print(name, binary_repr)
...@@ -122,7 +120,7 @@ def calc_sim(x, y, name="tensor"): ...@@ -122,7 +120,7 @@ def calc_sim(x, y, name="tensor"):
x, y = x.data.double(), y.data.double() x, y = x.data.double(), y.data.double()
denominator = (x * x + y * y).sum() denominator = (x * x + y * y).sum()
if denominator == 0: if denominator == 0:
print_red_warning(f'{name} all zero') print_red_warning(f"{name} all zero")
return 1 return 1
sim = 2 * (x * y).sum() / denominator sim = 2 * (x * y).sum() / denominator
return sim return sim
...@@ -132,21 +130,19 @@ def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True): ...@@ -132,21 +130,19 @@ def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
x_mask = torch.isfinite(x) x_mask = torch.isfinite(x)
y_mask = torch.isfinite(y) y_mask = torch.isfinite(y)
if not torch.all(x_mask == y_mask): if not torch.all(x_mask == y_mask):
print_red_warning(f'{name} Error: isfinite mask mismatch') print_red_warning(f"{name} Error: isfinite mask mismatch")
if raise_assert: if raise_assert:
raise AssertionError raise AssertionError
if not torch.isclose( if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all():
x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, print_red_warning(f"{name} Error: nonfinite value mismatch")
equal_nan=True).all():
print_red_warning(f'{name} Error: nonfinite value mismatch')
if raise_assert: if raise_assert:
raise AssertionError raise AssertionError
x = x.masked_fill(~x_mask, 0) x = x.masked_fill(~x_mask, 0)
y = y.masked_fill(~y_mask, 0) y = y.masked_fill(~y_mask, 0)
sim = calc_sim(x, y, name) sim = calc_sim(x, y, name)
diff = (1. - sim).item() diff = (1.0 - sim).item()
print(f'{diff=}') print(f"{diff=}")
if not (0 <= diff <= eps): if not (0 <= diff <= eps):
print_red_warning(f'{name} Error: {diff=}') print_red_warning(f"{name} Error: {diff=}")
if raise_assert: if raise_assert:
raise AssertionError raise AssertionError
...@@ -24,6 +24,7 @@ def get_configs(): ...@@ -24,6 +24,7 @@ def get_configs():
the parameter name to its chosen value. the parameter name to its chosen value.
""" """
import itertools import itertools
iter_params = dict( iter_params = dict(
block_M=[64, 128, 256], block_M=[64, 128, 256],
block_N=[64, 128, 256], block_N=[64, 128, 256],
...@@ -32,63 +33,62 @@ def get_configs(): ...@@ -32,63 +33,62 @@ def get_configs():
threads=[128, 256, 512], threads=[128, 256, 512],
split=[1, 2], split=[1, 2],
) )
return [{ return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
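# Each autotuning config is one dict per element of the Cartesian product, e.g. with
# a toy two-parameter subset:
import itertools

params = dict(block_M=[64, 128], threads=[128, 256])
configs = [dict(zip(params, values)) for values in itertools.product(*params.values())]
print(len(configs), configs[0])  # 4 {'block_M': 64, 'threads': 128}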
@tilelang.autotune(configs=get_configs(),) @tilelang.autotune(
configs=get_configs(),
)
@tilelang.jit( @tilelang.jit(
out_idx=[-1], out_idx=[-1],
pass_configs={ pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True},
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
},
) )
def matmul(M, def matmul(
N, M,
K, N,
in_dtype, K,
out_dtype, in_dtype,
accum_dtype, out_dtype,
source_format='uint', accum_dtype,
num_bits=4, source_format="uint",
fast_dequant=True, num_bits=4,
block_M=256, fast_dequant=True,
block_N=128, block_M=256,
block_K=128, block_N=128,
num_stages=2, block_K=128,
threads=256, num_stages=2,
split=1): threads=256,
split=1,
):
"""
Builds a parameterized TileLang/TIR matrix-multiplication kernel that dequantizes 4-bit FP inputs to BF16 on-the-fly and computes C = A @ B^T.
This function returns a tiled, autotunable prim_func implementing a block-wise GEMM with shared-memory buffering and a pipelined K-loop. The kernel accepts:
- A: dense input of shape (M, K) with dtype `in_dtype`.
- B: packed quantized input of shape (N, QK) where QK = K / (8 / num_bits) stored as `uint8`.
- C: output of shape (M, N) with dtype `out_dtype`.
The generated kernel supports two dequantization paths:
- fast_dequant (fast_dequant=True): calls an external mxfp dequantization intrinsic (twiddling-based) loaded from a C source returned by get_mxfp_intrin_group.
- simple dequant (fast_dequant=False): performs a pure-TIR FP4 -> BF16 conversion per element.
Important behavior and requirements:
- num_bits (default 4) is the bit-width of the quantized elements; storage_dtype is uint8 and num_elems_per_byte = 8 // num_bits.
- QK = K // num_elems_per_byte and Block_QK = block_K // num_elems_per_byte determine B and shared-buffer shapes.
- Asserts that K % (block_K * split) == 0; K must be divisible by block_K * split for the tiling to be valid.
- When fast_dequant is True, a valid mxfp intrinsic group (C source and function name) must be available via tilelang.quantize.get_mxfp_intrin_group.
- The kernel launches a 2D grid over ceildiv(N, block_N) and ceildiv(M, block_M) and uses `threads` threads per block with `num_stages` pipeline stages.
Parameters that alter kernel layout/behavior (brief):
- block_M, block_N, block_K: tile sizes for M, N, and K dimensions.
- num_stages: number of software pipeline stages for the K-loop.
- threads: number of threads used per kernel block.
- split: extra K-splitting factor; K must be divisible by block_K * split.
- source_format, num_bits: describe the quantized data layout passed to the mxfp intrinsics.
Returns:
A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel.
""" """
Builds a parameterized TileLang/TIR matrix-multiplication kernel that dequantizes 4-bit FP inputs to BF16 on-the-fly and computes C = A @ B^T.
This function returns a tiled, autotunable prim_func implementing a block-wise GEMM with shared-memory buffering and a pipelined K-loop. The kernel accepts:
- A: dense input of shape (M, K) with dtype `in_dtype`.
- B: packed quantized input of shape (N, QK) where QK = K / (8 / num_bits) stored as `uint8`.
- C: output of shape (M, N) with dtype `out_dtype`.
The generated kernel supports two dequantization paths:
- fast_dequant (fast_dequant=True): calls an external mxfp dequantization intrinsic (twiddling-based) loaded from a C source returned by get_mxfp_intrin_group.
- simple dequant (fast_dequant=False): performs a pure-TIR FP4 -> BF16 conversion per element.
Important behavior and requirements:
- num_bits (default 4) is the bit-width of the quantized elements; storage_dtype is uint8 and num_elems_per_byte = 8 // num_bits.
- QK = K // num_elems_per_byte and Block_QK = block_K // num_elems_per_byte determine B and shared-buffer shapes.
- Asserts that K % (block_K * split) == 0; K must be divisible by block_K * split for the tiling to be valid.
- When fast_dequant is True, a valid mxfp intrinsic group (C source and function name) must be available via tilelang.quantize.get_mxfp_intrin_group.
- The kernel launches a 2D grid over ceildiv(N, block_N) and ceildiv(M, block_M) and uses `threads` threads per block with `num_stages` pipeline stages.
Parameters that alter kernel layout/behavior (brief):
- block_M, block_N, block_K: tile sizes for M, N, and K dimensions.
- num_stages: number of software pipeline stages for the K-loop.
- threads: number of threads used per kernel block.
- split: extra K-splitting factor; K must be divisible by block_K * split.
- source_format, num_bits: describe the quantized data layout passed to the mxfp intrinsics.
Returns:
A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel.
"""
num_elems_per_byte = 8 // num_bits num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8" storage_dtype = "uint8"
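# Packing arithmetic behind the shapes described in the docstring, with the defaults
# used by main() below (num_bits = 4, K = 256, block_K = 128); plain illustration:
num_elems_per_byte_ex = 8 // 4              # 2 FP4 values per uint8 byte
QK_ex = 256 // num_elems_per_byte_ex        # 128 packed bytes per row of B
Block_QK_ex = 128 // num_elems_per_byte_ex  # 64 packed bytes per K-tile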
...@@ -189,8 +189,7 @@ def matmul(M, ...@@ -189,8 +189,7 @@ def matmul(M,
# Finally, store the dequantized data to shared memory. # Finally, store the dequantized data to shared memory.
for v in T.vectorized(0, local_size): for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v index = i * threads * local_size + tx * local_size + v
B_dequantize_shared[index // block_K, B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v]
index % block_K] = B_dequantize_local_thread[v]
return fast_dequant_bf16_fp4_twiddling return fast_dequant_bf16_fp4_twiddling
...@@ -215,30 +214,29 @@ def matmul(M, ...@@ -215,30 +214,29 @@ def matmul(M,
assert in_dtype in ["fp4"] assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"] assert out_dtype in ["bfloat16"]
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str):
scale: tir.PrimExpr, dtype: str):
""" """
Convert a 4-bit FP4 value packed in a uint8 byte into a bfloat16 value. Convert a 4-bit FP4 value packed in a uint8 byte into a bfloat16 value.
This helper extracts the 4-bit field located at the bit position `pos` within the This helper extracts the 4-bit field located at the bit position `pos` within the
byte `val`, interprets it as an FP4 (sign, exponent, mantissa) value, applies an byte `val`, interprets it as an FP4 (sign, exponent, mantissa) value, applies an
exponent `scale` offset to align it with bfloat16 exponent bias, clamps the exponent `scale` offset to align it with bfloat16 exponent bias, clamps the
resulting exponent to 8 bits, and returns the assembled bfloat16 bit pattern. resulting exponent to 8 bits, and returns the assembled bfloat16 bit pattern.
Parameters: Parameters:
nbit (int): Number of bits in the packed element; must be 4. nbit (int): Number of bits in the packed element; must be 4.
val (tir.PrimExpr): A uint8 value containing packed FP4 elements. val (tir.PrimExpr): A uint8 value containing packed FP4 elements.
pos (tir.PrimExpr): Index (0-based) of which FP4 nibble inside `val` to extract. pos (tir.PrimExpr): Index (0-based) of which FP4 nibble inside `val` to extract.
scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16. scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16.
dtype (str): Target dtype string; must be "bfloat16". dtype (str): Target dtype string; must be "bfloat16".
Returns: Returns:
tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value. tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value.
Notes: Notes:
- The function asserts `nbit == 4`, `dtype == "bfloat16"`, and that `val.dtype` is "uint8". - The function asserts `nbit == 4`, `dtype == "bfloat16"`, and that `val.dtype` is "uint8".
- The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16 - The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16
bit fields and clamps the computed exponent to fit into 8 bits. bit fields and clamps the computed exponent to fit into 8 bits.
""" """
assert nbit == 4 assert nbit == 4
assert dtype == "bfloat16" assert dtype == "bfloat16"
...@@ -254,8 +252,9 @@ def matmul(M, ...@@ -254,8 +252,9 @@ def matmul(M,
e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
m_f4 = f4 & tir.const(1, "uint16") m_f4 = f4 & tir.const(1, "uint16")
val_bf16 = tir.reinterpret( val_bf16 = tir.reinterpret(
"bfloat16", ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) "bfloat16",
| (m_f4 << tir.const(6, "uint16"))).astype("uint16")) ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) | (m_f4 << tir.const(6, "uint16"))).astype("uint16"),
)
return val_bf16 return val_bf16
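# Standalone illustration of the final bit assembly above: sign -> bit 15, the 8-bit
# (already re-biased) exponent -> bits 14..7, the single FP4 mantissa bit -> bit 6 of
# the bfloat16 pattern. How e_bf16 is derived from the packed nibble is outside this
# excerpt, so it is taken as given here; torch.uint16 needs a recent PyTorch build.
import torch

def assemble_bf16(s: int, e_bf16: int, m_f4: int) -> torch.Tensor:
    bits = (((s << 8) | e_bf16) << 7) | (m_f4 << 6)
    return torch.tensor([bits], dtype=torch.uint16).view(torch.bfloat16)

print(assemble_bf16(0, 127, 0))  # tensor([1.], dtype=torch.bfloat16)
print(assemble_bf16(0, 127, 1))  # tensor([1.5000], dtype=torch.bfloat16)
print(assemble_bf16(1, 128, 0))  # tensor([-2.], dtype=torch.bfloat16)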
@T.macro @T.macro
...@@ -292,32 +291,32 @@ def matmul(M, ...@@ -292,32 +291,32 @@ def matmul(M,
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype), B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
""" """
Kernel entry for the tiled, pipelined matmul used by the generated prim_func. Kernel entry for the tiled, pipelined matmul used by the generated prim_func.
This function implements a block-wise GEMM over a 2D grid (grid dims: ceildiv(N, block_N) x ceildiv(M, block_M)) with a thread block of `threads`. For each output block it: This function implements a block-wise GEMM over a 2D grid (grid dims: ceildiv(N, block_N) x ceildiv(M, block_M)) with a thread block of `threads`. For each output block it:
- Allocates shared buffers for A, the packed/quantized B, and a dequantized B tile. - Allocates shared buffers for A, the packed/quantized B, and a dequantized B tile.
- Allocates a fragment accumulator (C_local) and a shared output tile (C_shared) with a swizzled layout. - Allocates a fragment accumulator (C_local) and a shared output tile (C_shared) with a swizzled layout.
- Pipelines over K in chunks of `block_K` for `num_stages` stages: - Pipelines over K in chunks of `block_K` for `num_stages` stages:
- Loads A and packed B tiles into shared memory. - Loads A and packed B tiles into shared memory.
- Dequantizes B into B_dequantize_shared using either the fast (twiddling/external) or the simple (pure-TIR) dequantization routine. - Dequantizes B into B_dequantize_shared using either the fast (twiddling/external) or the simple (pure-TIR) dequantization routine.
- Performs a GEMM accumulating into C_local with B transposed. - Performs a GEMM accumulating into C_local with B transposed.
- Stores the accumulated block from C_local back to the global output C via C_shared. - Stores the accumulated block from C_local back to the global output C via C_shared.
Parameters: Parameters:
- A: input tile of shape (M, K) with dtype `in_dtype`. - A: input tile of shape (M, K) with dtype `in_dtype`.
- B: packed/quantized input of shape (N, QK) with storage dtype `storage_dtype` (quantized FP4 packing). - B: packed/quantized input of shape (N, QK) with storage dtype `storage_dtype` (quantized FP4 packing).
- C: output tensor of shape (M, N) with dtype `out_dtype`. - C: output tensor of shape (M, N) with dtype `out_dtype`.
Side effects: Side effects:
- Writes the computed output block into the global tensor `C`. - Writes the computed output block into the global tensor `C`.
- Uses and updates shared memory buffers and per-thread accumulators. - Uses and updates shared memory buffers and per-thread accumulators.
No value is returned. No value is returned.
""" """
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...@@ -327,9 +326,11 @@ def matmul(M, ...@@ -327,9 +326,11 @@ def matmul(M,
C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype) C_shared = T.alloc_shared((block_M, block_N), out_dtype)
T.annotate_layout({ T.annotate_layout(
C_shared: tilelang.layout.make_swizzled_layout(C_shared), {
}) C_shared: tilelang.layout.make_swizzled_layout(C_shared),
}
)
T.clear(C_local) T.clear(C_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages): for k in T.Pipelined(K // block_K, num_stages=num_stages):
...@@ -344,7 +345,7 @@ def matmul(M, ...@@ -344,7 +345,7 @@ def matmul(M,
T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
T.copy(C_local, C_shared) T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N]) T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N])
return main return main
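# Host-side sketch of the loop structure the kernel docstring describes (grid over
# N/M tiles, K-tile accumulation, B used transposed). It takes an already dequantized
# B of shape (N, K) and assumes the dimensions divide the tile sizes, as the defaults
# in main() below do; packing and shared-memory staging are deliberately left out, so
# this is an illustration, not the kernel.
import torch

def blocked_gemm_reference(A, B_deq, block_M=128, block_N=128, block_K=128):
    M, K = A.shape
    N = B_deq.shape[0]
    C = torch.zeros(M, N, dtype=torch.float32)
    for by in range(0, M, block_M):                # grid dimension over M tiles
        for bx in range(0, N, block_N):            # grid dimension over N tiles
            acc = torch.zeros(block_M, block_N, dtype=torch.float32)   # C_local
            for k in range(0, K, block_K):         # pipelined K loop in the kernel
                a = A[by:by + block_M, k:k + block_K].float()          # A_shared
                b = B_deq[bx:bx + block_N, k:k + block_K].float()      # B_dequantize_shared
                acc += a @ b.T                                         # gemm with transpose_B=True
            C[by:by + block_M, bx:bx + block_N] = acc                  # C_shared -> C
    return C
# For such shapes this matches torch.matmul(A.float(), B_deq.float().T).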
...@@ -409,8 +410,7 @@ def main(m=256, n=256, k=256, fast_dequant=True, tune=False): ...@@ -409,8 +410,7 @@ def main(m=256, n=256, k=256, fast_dequant=True, tune=False):
""" """
total_flops = 2 * m * n * k total_flops = 2 * m * n * k
if tune: if tune:
kernel = matmul( kernel = matmul(m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, fast_dequant=fast_dequant)
m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, fast_dequant=fast_dequant)
else: else:
kernel = matmul( kernel = matmul(
m, m,
...@@ -426,7 +426,8 @@ def main(m=256, n=256, k=256, fast_dequant=True, tune=False): ...@@ -426,7 +426,8 @@ def main(m=256, n=256, k=256, fast_dequant=True, tune=False):
block_K=128, block_K=128,
num_stages=2, num_stages=2,
threads=256, threads=256,
split=1) split=1,
)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto)
if fast_dequant: if fast_dequant:
profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01) profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01)
......
...@@ -7,29 +7,28 @@ import torch ...@@ -7,29 +7,28 @@ import torch
from dequantize_utils import torch_convert_bit_twiddling, torch_convert from dequantize_utils import torch_convert_bit_twiddling, torch_convert
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str):
dtype: str):
""" """
Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale. Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale.
This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its
bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by
`scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation. `scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation.
Parameters: Parameters:
nbit (int): Number of bits in the packed field (must be 4). nbit (int): Number of bits in the packed field (must be 4).
val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields. val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields.
pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field). pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field).
scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like). scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like).
dtype (str): Destination dtype string (must be "bfloat16"). dtype (str): Destination dtype string (must be "bfloat16").
Returns: Returns:
tir.PrimExpr: The resulting value reinterpreted as `bfloat16`. tir.PrimExpr: The resulting value reinterpreted as `bfloat16`.
Notes: Notes:
- Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8". - Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8".
- The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern. - The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern.
""" """
assert nbit == 4 assert nbit == 4
assert dtype == "bfloat16" assert dtype == "bfloat16"
assert val.dtype == "uint8" assert val.dtype == "uint8"
...@@ -43,9 +42,10 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale ...@@ -43,9 +42,10 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
# To handle the overflow, we may use the min function to limit the exponential part to 8 bits # To handle the overflow, we may use the min function to limit the exponential part to 8 bits
# e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) # e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
m_f4 = f4 & tir.const(1, "uint16") m_f4 = f4 & tir.const(1, "uint16")
val_bf16 = tir.reinterpret("bfloat16", val_bf16 = tir.reinterpret(
((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) "bfloat16",
| (m_f4 << tir.const(6, "uint16"))).astype("uint16")) ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) | (m_f4 << tir.const(6, "uint16"))).astype("uint16"),
)
return val_bf16 return val_bf16
...@@ -65,6 +65,7 @@ def get_configs(): ...@@ -65,6 +65,7 @@ def get_configs():
List[dict]: A list of configuration dictionaries covering all combinations. List[dict]: A list of configuration dictionaries covering all combinations.
""" """
import itertools import itertools
iter_params = dict( iter_params = dict(
block_M=[64, 128, 256], block_M=[64, 128, 256],
block_N=[64, 128, 256], block_N=[64, 128, 256],
...@@ -73,67 +74,71 @@ def get_configs(): ...@@ -73,67 +74,71 @@ def get_configs():
threads=[128, 256, 512], threads=[128, 256, 512],
split=[1, 2], split=[1, 2],
) )
return [{ return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
@tilelang.autotune(
configs=get_configs(),
@tilelang.autotune(configs=get_configs(),) )
@tilelang.jit(out_idx=[-1],) @tilelang.jit(
def matmul(M, out_idx=[-1],
N, )
K, def matmul(
in_dtype, M,
out_dtype, N,
accum_dtype, K,
source_format='uint', in_dtype,
num_bits=4, out_dtype,
scale_size=32, accum_dtype,
fast_dequant=True, source_format="uint",
with_bias=False, num_bits=4,
block_M=256, scale_size=32,
block_N=128, fast_dequant=True,
block_K=128, with_bias=False,
num_stages=2, block_M=256,
threads=256, block_N=128,
split=1): block_K=128,
num_stages=2,
threads=256,
split=1,
):
""" """
Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype. Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype.
The generated kernel accepts:
- A: dense matrix with element type `in_dtype`.
- B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)).
- Scale: per-block scale/exponent information used to dequantize B.
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (False): uses a simple elementwise dequantization helper.
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
out_dtype (str): output tensor element type (e.g., "bfloat16").
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
scale_size (int, optional): number of elements grouped per scale entry (default 32).
fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True).
block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128).
num_stages (int, optional): pipelining stages for K loop (default 2).
threads (int, optional): threads per block used by the kernel (default 256).
split (int, optional): split factor along K used by the scheduler (default 1).
with_bias (bool, optional): whether to add Bias to the output (default False).
Returns:
A T.prim_func implementing the tiled, pipelined GEMM that:
- loads tiled blocks of A and packed B to shared memory,
- dequantizes B via the chosen path into a shared dequantized tile,
- performs a tiled GEMM accumulating into local fragments,
- writes the final MxN block to the global output tensor.
Notes: The generated kernel accepts:
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name. - A: dense matrix with element type `in_dtype`.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile. - B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)).
- An assertion enforces that K % (block_K * split) == 0. - Scale: per-block scale/exponent information used to dequantize B.
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (False): uses a simple elementwise dequantization helper.
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
out_dtype (str): output tensor element type (e.g., "bfloat16").
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
scale_size (int, optional): number of elements grouped per scale entry (default 32).
fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True).
block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128).
num_stages (int, optional): pipelining stages for K loop (default 2).
threads (int, optional): threads per block used by the kernel (default 256).
split (int, optional): split factor along K used by the scheduler (default 1).
with_bias (bool, optional): whether to add Bias to the output (default False).
Returns:
A T.prim_func implementing the tiled, pipelined GEMM that:
- loads tiled blocks of A and packed B to shared memory,
- dequantizes B via the chosen path into a shared dequantized tile,
- performs a tiled GEMM accumulating into local fragments,
- writes the final MxN block to the global output tensor.
Notes:
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile.
- An assertion enforces that K % (block_K * split) == 0.
""" """
num_elems_per_byte = 8 // num_bits num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8" storage_dtype = "uint8"
...@@ -150,6 +155,7 @@ def matmul(M, ...@@ -150,6 +155,7 @@ def matmul(M,
assert K % (block_K * split) == 0 assert K % (block_K * split) == 0
from tilelang.quantize import get_mxfp_intrin_group from tilelang.quantize import get_mxfp_intrin_group
# fast_dequant_bf16_fp4_twiddling # fast_dequant_bf16_fp4_twiddling
mxfp_intrin_info = get_mxfp_intrin_group( mxfp_intrin_info = get_mxfp_intrin_group(
out_dtype=in_dtype, out_dtype=in_dtype,
...@@ -252,8 +258,7 @@ def matmul(M, ...@@ -252,8 +258,7 @@ def matmul(M,
for v in T.vectorized(0, local_size): for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v index = i * threads * local_size + tx * local_size + v
B_dequantize_shared[index // block_K, B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v]
index % block_K] = B_dequantize_local_thread[v]
return fast_dequant_bf16_fp4_twiddling return fast_dequant_bf16_fp4_twiddling
...@@ -301,33 +306,32 @@ def matmul(M, ...@@ -301,33 +306,32 @@ def matmul(M,
B_local[i, j // num_elems_per_byte], B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte, j % num_elems_per_byte,
Scale[ Scale[
bx * block_N + i, k * block_K // scale_size + j // bx * block_N + i, k * block_K // scale_size + j // scale_size
scale_size], # Scale stores the shared exponent as a uint8; the value is scaled by 2**Scale ], # Scale stores the shared exponent as a uint8; the value is scaled by 2**Scale
dtype=out_dtype, dtype=out_dtype,
) * T.shift_left( ) * T.shift_left(1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size]))
1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size]))
T.copy(B_dequantize_local, B_dequantize_shared) T.copy(B_dequantize_local, B_dequantize_shared)
return simple_dequant_bf16_fp4 return simple_dequant_bf16_fp4
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype), B: T.Tensor(B_shape, storage_dtype),
Scale: T.Tensor(Scale_shape, storage_dtype), Scale: T.Tensor(Scale_shape, storage_dtype),
Bias: T.Tensor(Bias_shape, out_dtype), Bias: T.Tensor(Bias_shape, out_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
""" """
Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C. Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C.
This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function. This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function.
Parameters are self-descriptive in the signature; notable behaviors: Parameters are self-descriptive in the signature; notable behaviors:
- B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM. - B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM.
- The selected dequantization path is controlled by the outer-scope flag `fast_dequant`. - The selected dequantization path is controlled by the outer-scope flag `fast_dequant`.
- The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization). - The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization).
- The function writes results in-place into C. - The function writes results in-place into C.
""" """
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...@@ -337,23 +341,26 @@ def matmul(M, ...@@ -337,23 +341,26 @@ def matmul(M,
C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype) C_shared = T.alloc_shared((block_M, block_N), out_dtype)
T.annotate_layout({ T.annotate_layout(
A_shared: tilelang.layout.make_swizzled_layout(A_shared), {
B_shared: tilelang.layout.make_swizzled_layout(B_shared), A_shared: tilelang.layout.make_swizzled_layout(A_shared),
C_shared: tilelang.layout.make_swizzled_layout(C_shared), B_shared: tilelang.layout.make_swizzled_layout(B_shared),
}) C_shared: tilelang.layout.make_swizzled_layout(C_shared),
}
)
if with_bias: if with_bias:
T.annotate_layout({ T.annotate_layout(
Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared), {
}) Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared),
}
)
if threads == 512: if threads == 512:
T.disable_warp_group_reg_alloc() T.disable_warp_group_reg_alloc()
if with_bias: if with_bias:
T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N], T.copy(Bias[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N], Bias_shared)
Bias_shared)
T.copy(Bias_shared, C_local) T.copy(Bias_shared, C_local)
else: else:
T.clear(C_local) T.clear(C_local)
...@@ -368,7 +375,7 @@ def matmul(M, ...@@ -368,7 +375,7 @@ def matmul(M,
T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
T.copy(C_local, C_shared) T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N]) T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N])
return main return main
...@@ -389,7 +396,7 @@ def ref_program_twiddling(A, qB, Scale, Bias=None): ...@@ -389,7 +396,7 @@ def ref_program_twiddling(A, qB, Scale, Bias=None):
""" """
dtypeC = "bfloat16" dtypeC = "bfloat16"
B = torch_convert_bit_twiddling(qB) B = torch_convert_bit_twiddling(qB)
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC)) C = C.to(torch.__getattribute__(dtypeC))
return C return C
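# What the Scale gather above does, on toy shapes: Scale keeps one exponent per 32
# consecutive columns of B, and column j of row n is multiplied by 2 ** Scale[n, j // 32].
# Same pattern with a group size of 4 for readability:
import torch

B_toy = torch.ones(2, 8)
Scale_toy = torch.tensor([[0, 1], [2, 3]])
B_toy = B_toy * 2.0 ** Scale_toy[:, torch.arange(8) // 4]
print(B_toy)
# tensor([[1., 1., 1., 1., 2., 2., 2., 2.],
#         [4., 4., 4., 4., 8., 8., 8., 8.]])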
...@@ -412,7 +419,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias): ...@@ -412,7 +419,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias):
""" """
dtypeC = "bfloat16" dtypeC = "bfloat16"
B = torch_convert_bit_twiddling(qB) B = torch_convert_bit_twiddling(qB)
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
C = C.to(torch.__getattribute__(dtypeC)) C = C.to(torch.__getattribute__(dtypeC))
return C return C
...@@ -436,7 +443,7 @@ def ref_program_simple(A, qB, Scale, Bias=None): ...@@ -436,7 +443,7 @@ def ref_program_simple(A, qB, Scale, Bias=None):
""" """
dtypeC = "bfloat16" dtypeC = "bfloat16"
B = torch_convert(qB) B = torch_convert(qB)
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC)) C = C.to(torch.__getattribute__(dtypeC))
return C return C
...@@ -464,7 +471,7 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias): ...@@ -464,7 +471,7 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias):
""" """
dtypeC = "bfloat16" dtypeC = "bfloat16"
B = torch_convert(qB) B = torch_convert(qB)
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)]) B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
C = C.to(torch.__getattribute__(dtypeC)) C = C.to(torch.__getattribute__(dtypeC))
return C return C
...@@ -491,16 +498,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, ...@@ -491,16 +498,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
if tune: if tune:
kernel = matmul( kernel = matmul(
m, m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias
n, )
k,
"bfloat16",
"bfloat16",
"float32",
num_bits=4,
scale_size=scale_size,
fast_dequant=fast_dequant,
with_bias=with_bias)
else: else:
kernel = matmul( kernel = matmul(
m, m,
...@@ -518,7 +517,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, ...@@ -518,7 +517,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
threads=256, threads=256,
split=1, split=1,
fast_dequant=fast_dequant, fast_dequant=fast_dequant,
with_bias=with_bias) with_bias=with_bias,
)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto)
......
...@@ -7,29 +7,28 @@ import torch ...@@ -7,29 +7,28 @@ import torch
from dequantize_utils import torch_convert_bit_twiddling, torch_convert from dequantize_utils import torch_convert_bit_twiddling, torch_convert
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str):
dtype: str):
""" """
Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale. Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale.
This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its
bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by
`scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation. `scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation.
Parameters: Parameters:
nbit (int): Number of bits in the packed field (must be 4). nbit (int): Number of bits in the packed field (must be 4).
val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields. val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields.
pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field). pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field).
scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like). scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like).
dtype (str): Destination dtype string (must be "bfloat16"). dtype (str): Destination dtype string (must be "bfloat16").
Returns: Returns:
tir.PrimExpr: The resulting value reinterpreted as `bfloat16`. tir.PrimExpr: The resulting value reinterpreted as `bfloat16`.
Notes: Notes:
- Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8". - Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8".
- The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern. - The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern.
""" """
assert nbit == 4 assert nbit == 4
assert dtype == "bfloat16" assert dtype == "bfloat16"
assert val.dtype == "uint8" assert val.dtype == "uint8"
...@@ -43,9 +42,10 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale ...@@ -43,9 +42,10 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
# To handle the overflow, we may use the min function to limit the exponential part to 8 bits # To handle the overflow, we may use the min function to limit the exponential part to 8 bits
# e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16")) # e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
m_f4 = f4 & tir.const(1, "uint16") m_f4 = f4 & tir.const(1, "uint16")
val_bf16 = tir.reinterpret("bfloat16", val_bf16 = tir.reinterpret(
((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) "bfloat16",
| (m_f4 << tir.const(6, "uint16"))).astype("uint16")) ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) | (m_f4 << tir.const(6, "uint16"))).astype("uint16"),
)
return val_bf16 return val_bf16
...@@ -65,6 +65,7 @@ def get_configs(): ...@@ -65,6 +65,7 @@ def get_configs():
List[dict]: A list of configuration dictionaries covering all combinations. List[dict]: A list of configuration dictionaries covering all combinations.
""" """
import itertools import itertools
iter_params = dict( iter_params = dict(
block_M=[64, 128, 256], block_M=[64, 128, 256],
block_N=[64, 128, 256], block_N=[64, 128, 256],
...@@ -73,67 +74,71 @@ def get_configs(): ...@@ -73,67 +74,71 @@ def get_configs():
threads=[128, 256, 512], threads=[128, 256, 512],
split=[1, 2], split=[1, 2],
) )
return [{ return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
@tilelang.autotune(
configs=get_configs(),
@tilelang.autotune(configs=get_configs(),) )
@tilelang.jit(out_idx=[-1],) @tilelang.jit(
def matmul(M, out_idx=[-1],
N, )
K, def matmul(
in_dtype, M,
out_dtype, N,
accum_dtype, K,
source_format='uint', in_dtype,
num_bits=4, out_dtype,
scale_size=32, accum_dtype,
fast_dequant=True, source_format="uint",
with_bias=False, num_bits=4,
block_M=256, scale_size=32,
block_N=128, fast_dequant=True,
block_K=128, with_bias=False,
num_stages=2, block_M=256,
threads=256, block_N=128,
split=1): block_K=128,
num_stages=2,
threads=256,
split=1,
):
""" """
Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype. Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype.
The generated kernel accepts:
- A: dense matrix with element type `in_dtype`.
- B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)).
- Scale: per-block scale/exponent information used to dequantize B.
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (False): uses a simple elementwise dequantization helper.
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
out_dtype (str): output tensor element type (e.g., "bfloat16").
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
scale_size (int, optional): number of elements grouped per scale entry (default 32).
fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True).
block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128).
num_stages (int, optional): pipelining stages for K loop (default 2).
threads (int, optional): threads per block used by the kernel (default 256).
split (int, optional): split factor along K used by the scheduler (default 1).
with_bias (bool, optional): whether to add Bias to the output (default False).
Returns:
A T.prim_func implementing the tiled, pipelined GEMM that:
- loads tiled blocks of A and packed B to shared memory,
- dequantizes B via the chosen path into a shared dequantized tile,
- performs a tiled GEMM accumulating into local fragments,
- writes the final MxN block to the global output tensor.
Notes: The generated kernel accepts:
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name. - A: dense matrix with element type `in_dtype`.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile. - B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)).
- An assertion enforces that K % (block_K * split) == 0. - Scale: per-block scale/exponent information used to dequantize B.
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (False): uses a simple elementwise dequantization helper.
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
out_dtype (str): output tensor element type (e.g., "bfloat16").
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
scale_size (int, optional): number of elements grouped per scale entry (default 32).
fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True).
block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128).
num_stages (int, optional): pipelining stages for K loop (default 2).
threads (int, optional): threads per block used by the kernel (default 256).
split (int, optional): split factor along K used by the scheduler (default 1).
with_bias (bool, optional): whether to add Bias to the output (default False).
Returns:
A T.prim_func implementing the tiled, pipelined GEMM that:
- loads tiled blocks of A and packed B to shared memory,
- dequantizes B via the chosen path into a shared dequantized tile,
- performs a tiled GEMM accumulating into local fragments,
- writes the final MxN block to the global output tensor.
Notes:
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile.
- An assertion enforces that K % (block_K * split) == 0.
""" """
num_elems_per_byte = 8 // num_bits num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8" storage_dtype = "uint8"
...@@ -150,6 +155,7 @@ def matmul(M, ...@@ -150,6 +155,7 @@ def matmul(M,
assert K % (block_K * split) == 0 assert K % (block_K * split) == 0
from tilelang.quantize import get_mxfp_intrin_group from tilelang.quantize import get_mxfp_intrin_group
# fast_dequant_bf16_fp4_twiddling # fast_dequant_bf16_fp4_twiddling
mxfp_intrin_info = get_mxfp_intrin_group( mxfp_intrin_info = get_mxfp_intrin_group(
out_dtype=in_dtype, out_dtype=in_dtype,
...@@ -252,8 +258,7 @@ def matmul(M, ...@@ -252,8 +258,7 @@ def matmul(M,
for v in T.vectorized(0, local_size): for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v index = i * threads * local_size + tx * local_size + v
B_dequantize_shared[index // block_K, B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v]
index % block_K] = B_dequantize_local_thread[v]
return fast_dequant_bf16_fp4_twiddling return fast_dequant_bf16_fp4_twiddling
...@@ -301,8 +306,8 @@ def matmul(M, ...@@ -301,8 +306,8 @@ def matmul(M,
B_local[i, j // num_elems_per_byte], B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte, j % num_elems_per_byte,
Scale_shared[ Scale_shared[
i, k * block_K // scale_size + j // i, k * block_K // scale_size + j // scale_size
scale_size], # Scale stores the shared exponent as a uint8; the value is scaled by 2**Scale ], # Scale stores the shared exponent as a uint8; the value is scaled by 2**Scale
dtype=out_dtype, dtype=out_dtype,
) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size])) ) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size]))
T.copy(B_dequantize_local, B_dequantize_shared) T.copy(B_dequantize_local, B_dequantize_shared)
...@@ -311,22 +316,22 @@ def matmul(M, ...@@ -311,22 +316,22 @@ def matmul(M,
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype), B: T.Tensor(B_shape, storage_dtype),
Scale: T.Tensor(Scale_shape, storage_dtype), Scale: T.Tensor(Scale_shape, storage_dtype),
Bias: T.Tensor(Bias_shape, out_dtype), Bias: T.Tensor(Bias_shape, out_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
""" """
Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C. Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C.
This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function. This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function.
Parameters are self-descriptive in the signature; notable behaviors: Parameters are self-descriptive in the signature; notable behaviors:
- B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM. - B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM.
- The selected dequantization path is controlled by the outer-scope flag `fast_dequant`. - The selected dequantization path is controlled by the outer-scope flag `fast_dequant`.
- The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization). - The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization).
- The function writes results in-place into C. - The function writes results in-place into C.
""" """
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...@@ -339,16 +344,20 @@ def matmul(M, ...@@ -339,16 +344,20 @@ def matmul(M,
# May use much more shared memory than necessary # May use much more shared memory than necessary
Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype) Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype)
T.annotate_layout({ T.annotate_layout(
A_shared: tilelang.layout.make_swizzled_layout(A_shared), {
B_shared: tilelang.layout.make_swizzled_layout(B_shared), A_shared: tilelang.layout.make_swizzled_layout(A_shared),
C_shared: tilelang.layout.make_swizzled_layout(C_shared), B_shared: tilelang.layout.make_swizzled_layout(B_shared),
}) C_shared: tilelang.layout.make_swizzled_layout(C_shared),
}
)
if with_bias: if with_bias:
T.annotate_layout({ T.annotate_layout(
Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared), {
}) Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared),
}
)
if threads == 512: if threads == 512:
T.disable_warp_group_reg_alloc() T.disable_warp_group_reg_alloc()
...@@ -357,26 +366,24 @@ def matmul(M, ...@@ -357,26 +366,24 @@ def matmul(M,
# T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N], # T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N],
# Bias_shared) # Bias_shared)
# T.copy(Bias_shared, C_local) # T.copy(Bias_shared, C_local)
T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N], T.copy(Bias[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N], C_local)
C_local)
else: else:
T.clear(C_local) T.clear(C_local)
# Use 1D TMA to load Scale # Use 1D TMA to load Scale
T.copy(Scale[bx * block_N:(bx + 1) * block_N, :], Scale_shared) T.copy(Scale[bx * block_N : (bx + 1) * block_N, :], Scale_shared)
for k in T.Pipelined(K // block_K, num_stages=num_stages): for k in T.Pipelined(K // block_K, num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared) T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
if fast_dequant: if fast_dequant:
get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, k)
k)
else: else:
get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k) get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k)
T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True) T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
T.copy(C_local, C_shared) T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N]) T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N])
return main return main
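As a side note on the packed layout consumed above: with num_bits=4, each uint8 of B stores two elements, so column j of the logical weight lives in byte j // num_elems_per_byte at nibble j % num_elems_per_byte. A minimal standalone sketch of that unpacking in plain PyTorch (function and variable names here are illustrative, not part of the kernel API; it mirrors the unsigned per-element formula used by the reference programs below):

import torch

def unpack_uint8_nibbles(qB: torch.Tensor) -> torch.Tensor:
    # qB: (N, K // 2) uint8, two 4-bit values per byte, low nibble first.
    lo = qB & 0xF
    hi = (qB >> 4) & 0xF
    # Interleave so that column j comes from byte j // 2, nibble j % 2.
    return torch.stack([lo, hi], dim=-1).reshape(qB.shape[0], qB.shape[1] * 2)

qB = torch.randint(0, 256, (4, 8), dtype=torch.uint8)
B = unpack_uint8_nibbles(qB)
assert B[1, 5] == ((qB[1, 5 // 2] >> (4 * (5 % 2))) & 0xF)  # same as the loop-based reference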
...@@ -399,7 +406,7 @@ def ref_program_twiddling(A, qB, Scale, Bias=None): ...@@ -399,7 +406,7 @@ def ref_program_twiddling(A, qB, Scale, Bias=None):
B = torch_convert_bit_twiddling(qB) B = torch_convert_bit_twiddling(qB)
for i in range(B.shape[0]): for i in range(B.shape[0]):
for j in range(B.shape[1]): for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32])) B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32]))
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC)) C = C.to(torch.__getattribute__(dtypeC))
return C return C
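The doubly nested loop above states the scaling rule directly: B[i][j] *= 2 ** Scale[i][j // 32]. When the reference becomes slow for larger shapes, the same rule can be applied vectorized; a sketch under the same one-exponent-per-scale_size-columns assumption (this is the form the MoE reference further below uses):

import torch

def scale_by_shared_exponent(B: torch.Tensor, Scale: torch.Tensor, scale_size: int = 32) -> torch.Tensor:
    # B: (N, K) dequantized values; Scale: (N, K // scale_size) uint8 exponents.
    cols = torch.arange(B.shape[1], device=B.device) // scale_size
    return B * (2.0 ** Scale[:, cols].to(B.dtype))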
...@@ -424,7 +431,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias): ...@@ -424,7 +431,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias):
B = torch_convert_bit_twiddling(qB) B = torch_convert_bit_twiddling(qB)
for i in range(B.shape[0]): for i in range(B.shape[0]):
for j in range(B.shape[1]): for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32])) B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32]))
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
C = C.to(torch.__getattribute__(dtypeC)) C = C.to(torch.__getattribute__(dtypeC))
return C return C
...@@ -450,7 +457,7 @@ def ref_program_simple(A, qB, Scale, Bias=None): ...@@ -450,7 +457,7 @@ def ref_program_simple(A, qB, Scale, Bias=None):
B = torch_convert(qB) B = torch_convert(qB)
for i in range(B.shape[0]): for i in range(B.shape[0]):
for j in range(B.shape[1]): for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32])) B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32]))
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC)) C = C.to(torch.__getattribute__(dtypeC))
return C return C
...@@ -480,7 +487,7 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias): ...@@ -480,7 +487,7 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias):
B = torch_convert(qB) B = torch_convert(qB)
for i in range(B.shape[0]): for i in range(B.shape[0]):
for j in range(B.shape[1]): for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32])) B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32]))
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
C = C.to(torch.__getattribute__(dtypeC)) C = C.to(torch.__getattribute__(dtypeC))
return C return C
...@@ -507,16 +514,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, ...@@ -507,16 +514,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
if tune: if tune:
kernel = matmul( kernel = matmul(
m, m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias
n, )
k,
"bfloat16",
"bfloat16",
"float32",
num_bits=4,
scale_size=scale_size,
fast_dequant=fast_dequant,
with_bias=with_bias)
else: else:
kernel = matmul( kernel = matmul(
m, m,
...@@ -534,7 +533,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, ...@@ -534,7 +533,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
threads=256, threads=256,
split=1, split=1,
fast_dequant=fast_dequant, fast_dequant=fast_dequant,
with_bias=with_bias) with_bias=with_bias,
)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto)
...@@ -24,6 +24,7 @@ def matmul( ...@@ -24,6 +24,7 @@ def matmul(
num_bits=4, num_bits=4,
): ):
from tilelang.quantize import _tir_packed_to_unsigned_convert from tilelang.quantize import _tir_packed_to_unsigned_convert
num_elems_per_byte = 8 // num_bits num_elems_per_byte = 8 // num_bits
storage_dtype = "int8" storage_dtype = "int8"
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
...@@ -39,9 +40,9 @@ def matmul( ...@@ -39,9 +40,9 @@ def matmul(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype), B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
...@@ -58,21 +59,19 @@ def matmul( ...@@ -58,21 +59,19 @@ def matmul(
T.copy(A[by * block_M, k * block_K], A_shared) T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared) T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
for i in T.serial(block_N * block_K // num_elems_per_byte // for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * local_size_compressed)):
(threads * local_size_compressed)):
for v in T.vectorized(0, local_size_compressed): for v in T.vectorized(0, local_size_compressed):
index = i * threads * local_size_compressed + tx * local_size_compressed + v index = i * threads * local_size_compressed + tx * local_size_compressed + v
vi = index // (block_K // num_elems_per_byte) vi = index // (block_K // num_elems_per_byte)
vj = index % (block_K // num_elems_per_byte) vj = index % (block_K // num_elems_per_byte)
B_local[v] = B_shared[vi, vj] B_local[v] = B_shared[vi, vj]
for v in T.serial(0, local_size): for v in T.serial(0, local_size):
B_dequantize_local[v] = _tir_packed_to_unsigned_convert( B_dequantize_local[v] = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
storage_type, storage_nbit)( num_bits,
num_bits, B_local[v // num_elems_per_byte],
B_local[v // num_elems_per_byte], v % num_elems_per_byte,
v % num_elems_per_byte, dtype=in_dtype,
dtype=in_dtype, )
)
for v in T.vectorized(0, local_size): for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v index = i * threads * local_size + tx * local_size + v
vi = index // block_K vi = index // block_K
...@@ -121,9 +120,7 @@ def run_gemm( ...@@ -121,9 +120,7 @@ def run_gemm(
def ref_program(A, qB): def ref_program(A, qB):
import torch import torch
B = ( B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device)
torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
dtype=torch.half).to(torch.half).to(A.device))
for i in range(B.shape[0]): for i in range(B.shape[0]):
for j in range(B.shape[1]): for j in range(B.shape[1]):
B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
...@@ -146,9 +143,11 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( ...@@ -146,9 +143,11 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
): ):
from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout as make_swizzle_layout from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout as make_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import ( from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitterWithLadderTransform,) TensorCoreIntrinEmitterWithLadderTransform,
)
from bitblas.gpu.intrin.lop3 import decode_i4_to_f16 from bitblas.gpu.intrin.lop3 import decode_i4_to_f16
assert in_dtype in [ assert in_dtype in [
"float16", "float16",
"int8", "int8",
...@@ -192,8 +191,7 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( ...@@ -192,8 +191,7 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
pad_factor = 8 pad_factor = 8
A_shape = (M, K) A_shape = (M, K)
B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k // num_elems_per_byte)
micro_size_k // num_elems_per_byte)
A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K) A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K)
B_shared_shape = ( B_shared_shape = (
block_N // micro_size_y, block_N // micro_size_y,
...@@ -228,7 +226,8 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( ...@@ -228,7 +226,8 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
chunk=chunk, chunk=chunk,
reduce_k=reduce_k, reduce_k=reduce_k,
transform_kind_b=transform_b, transform_kind_b=transform_b,
num_elems_per_byte=num_elems_per_byte) num_elems_per_byte=num_elems_per_byte,
)
vec_load_qb = 16 vec_load_qb = 16
if block_N * (block_K // reduce_k) // num_elems_per_byte // threads < vec_load_qb: if block_N * (block_K // reduce_k) // num_elems_per_byte // threads < vec_load_qb:
...@@ -236,14 +235,11 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( ...@@ -236,14 +235,11 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype), B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype), C: T.Tensor((M, N), out_dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads, prelude=decode_i4_to_f16) as (bx, by):
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads,
prelude=decode_i4_to_f16) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope) B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
...@@ -255,40 +251,36 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( ...@@ -255,40 +251,36 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
thread_binding = T.get_thread_binding(0) thread_binding = T.get_thread_binding(0)
rk = T.get_thread_binding(1) rk = T.get_thread_binding(1)
T.annotate_layout({ T.annotate_layout(
A_shared: make_swizzle_layout(A_shared), {
}) A_shared: make_swizzle_layout(A_shared),
}
)
T.use_swizzle(panel_size=10) T.use_swizzle(panel_size=10)
T.clear(C_local) T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage): for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory # Load A into shared memory
for i, k in T.Parallel(block_M, (block_K // reduce_k)): for i, k in T.Parallel(block_M, (block_K // reduce_k)):
vk = rk * (block_K // reduce_k) + k vk = rk * (block_K // reduce_k) + k
A_shared[i, vk] = A[by * block_M + i, ko * block_K + vk] A_shared[i, vk] = A[by * block_M + i, ko * block_K + vk]
# TODO(lei): the Layout Inference Pass does not handle the four-dim int8 load efficiently # TODO(lei): the Layout Inference Pass does not handle the four-dim int8 load efficiently
for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte // for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte // (threads * vec_load_qb)):
(threads * vec_load_qb)):
for v in T.vectorized(0, vec_load_qb): for v in T.vectorized(0, vec_load_qb):
t = thread_binding t = thread_binding
idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v
vkk = idx % (micro_size_k // num_elems_per_byte) vkk = idx % (micro_size_k // num_elems_per_byte)
vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y
vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % ( vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % (block_K // micro_size_k)
block_K // micro_size_k) vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y // (block_K // micro_size_k)) % (
vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y // block_N // micro_size_y
(block_K // micro_size_k)) % ( )
block_N // micro_size_y) B_shared[vj, vk, vjj, vkk] = B[bx * (block_N // micro_size_y) + vj, ko * (block_K // micro_size_k) + vk, vjj, vkk]
B_shared[vj, vk, vjj,
vkk] = B[bx * (block_N // micro_size_y) + vj,
ko * (block_K // micro_size_k) + vk, vjj, vkk]
for ki in T.serial(0, (block_K // (micro_size_k * reduce_k))): for ki in T.serial(0, (block_K // (micro_size_k * reduce_k))):
# Load A into fragment # Load A into fragment
mma_emitter.ldmatrix_a( mma_emitter.ldmatrix_a(
A_local, A_local,
...@@ -307,9 +299,13 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( ...@@ -307,9 +299,13 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
for j in T.serial(warp_cols): for j in T.serial(warp_cols):
local_size_b = mma_emitter.local_size_b local_size_b = mma_emitter.local_size_b
T.call_extern('handle', 'decode_i4u_to_f16', T.call_extern(
T.address_of(B_local[j * local_size_b // num_elems_per_byte]), "handle",
T.address_of(B_dequantize_local[j * local_size_b]), 8) "decode_i4u_to_f16",
T.address_of(B_local[j * local_size_b // num_elems_per_byte]),
T.address_of(B_dequantize_local[j * local_size_b]),
8,
)
mma_emitter.mma(A_local, B_dequantize_local, C_local) mma_emitter.mma(A_local, B_dequantize_local, C_local)
...@@ -328,7 +324,8 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( ...@@ -328,7 +324,8 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
reduced_accum_res[0], reduced_accum_res[0],
rk, rk,
dtype="handle", dtype="handle",
)) )
)
if rk == 0: if rk == 0:
C_local[n] = reduced_accum_res[0] C_local[n] = reduced_accum_res[0]
...@@ -340,9 +337,9 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4( ...@@ -340,9 +337,9 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
for i, j in T.Parallel(block_M, (block_N // reduce_k)): for i, j in T.Parallel(block_M, (block_N // reduce_k)):
vj = rk * (block_N // reduce_k) + j vj = rk * (block_N // reduce_k) + j
C[by * block_M + i, C[by * block_M + i, bx * block_N + vj] = C_shared[
bx * block_N + vj] = C_shared[i // micro_size_x, vj // micro_size_y, i // micro_size_x, vj // micro_size_y, i % micro_size_x, vj % micro_size_y
i % micro_size_x, vj % micro_size_y] ]
return main return main
...@@ -357,8 +354,8 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct ...@@ -357,8 +354,8 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct
transform_b, transform_b,
): ):
import bitblas import bitblas
matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
M, N, K, in_dtype, out_dtype, accum_dtype, transform_b) matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(M, N, K, in_dtype, out_dtype, accum_dtype, transform_b)
kernel = tilelang.compile(matmul, out_idx=[2]) kernel = tilelang.compile(matmul, out_idx=[2])
src_code = kernel.get_kernel_source() src_code = kernel.get_kernel_source()
...@@ -371,8 +368,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct ...@@ -371,8 +368,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct
storage_dtype = "int8" storage_dtype = "int8"
A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
qB = torch.randint( qB = torch.randint(0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype))
0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype)) C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
ladder_permutate_config = bitblas.ops.LadderPermutateConfig( ladder_permutate_config = bitblas.ops.LadderPermutateConfig(
...@@ -407,9 +403,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct ...@@ -407,9 +403,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct
# Ensure that the latency is not None # Ensure that the latency is not None
assert latency is not None assert latency is not None
B = ( B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device)
torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
dtype=torch.half).to(torch.half).to(A.device))
for i in range(B.shape[0]): for i in range(B.shape[0]):
for j in range(B.shape[1]): for j in range(B.shape[1]):
B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
...@@ -429,8 +423,7 @@ def test_run_dequantize_gemm(): ...@@ -429,8 +423,7 @@ def test_run_dequantize_gemm():
@tilelang.testing.requires_package("bitblas") @tilelang.testing.requires_package("bitblas")
def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(): def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4():
assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness( assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(256, 1024, 512, "float16", "float16", "float16", 3)
256, 1024, 512, "float16", "float16", "float16", 3)
def main(): def main():
...@@ -21,18 +21,17 @@ def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: ...@@ -21,18 +21,17 @@ def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype:
e_f16 = e_f4 + tir.const(14, "uint16") e_f16 = e_f4 + tir.const(14, "uint16")
m_f4 = f4 & tir.const(1, "uint16") m_f4 = f4 & tir.const(1, "uint16")
m_f16 = m_f4 m_f16 = m_f4
val_f16 = tir.reinterpret("float16", val_f16 = tir.reinterpret(
((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16") "float16", ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16") | m_f16 << tir.const(9, "uint16")).astype("uint16")
| m_f16 << tir.const(9, "uint16")).astype("uint16")) )
# return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16) # return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)
return val_f16 return val_f16
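For reference, the helper above re-biases a 4-bit E2M1 value into the float16 bit layout: sign into bit 15, exponent plus 14 (fp4 bias 1 versus float16 bias 15) into bits 10-14, and the single mantissa bit into bit 9. A standalone pure-Python sketch of the same mapping, assuming the E2M1 interpretation of the nibble; the e_f4 == 0 subnormals are the case the commented-out Select would zero out:

import struct

def fp4_e2m1_to_float(nibble: int) -> float:
    s = (nibble >> 3) & 0x1   # sign
    e = (nibble >> 1) & 0x3   # 2-bit exponent, bias 1
    m = nibble & 0x1          # 1-bit mantissa
    if e == 0:
        # Subnormal fp4 values (+/-0 and +/-0.5), handled here explicitly.
        return (-1.0) ** s * 0.5 * m
    bits = (s << 15) | ((e + 14) << 10) | (m << 9)
    return struct.unpack("<e", struct.pack("<H", bits))[0]

# Non-negative E2M1 values: [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
print([fp4_e2m1_to_float(v) for v in range(8)])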
def torch_convert(tensor): def torch_convert(tensor):
def print_bit(name, val): def print_bit(name, val):
val_cpu = val.cpu().item() val_cpu = val.cpu().item()
binary_repr = f'{val_cpu:032b}' binary_repr = f"{val_cpu:032b}"
print(name, binary_repr) print(name, binary_repr)
def _convert(val, pos): def _convert(val, pos):
...@@ -68,8 +67,8 @@ def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): ...@@ -68,8 +67,8 @@ def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
@T.prim_func @T.prim_func
def main( def main(
B: T.Tensor(B_shape, storage_dtype), B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((N, K), in_dtype), C: T.Tensor((N, K), in_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
...@@ -118,19 +117,11 @@ def get_configs(): ...@@ -118,19 +117,11 @@ def get_configs():
splits = [1] splits = [1]
_configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits)) _configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits))
configs = [{ configs = [{"block_M": c[0], "block_N": c[1], "block_K": c[2], "num_stages": c[3], "threads": c[4], "split": c[5]} for c in _configs]
'block_M': c[0],
'block_N': c[1],
'block_K': c[2],
'num_stages': c[3],
'threads': c[4],
'split': c[5]
} for c in _configs]
return configs return configs
def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
@tilelang.jit(out_idx=[2]) @tilelang.jit(out_idx=[2])
def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1): def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1):
num_elems_per_byte = 8 // num_bits num_elems_per_byte = 8 // num_bits
...@@ -145,17 +136,12 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): ...@@ -145,17 +136,12 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
@T.prim_func @T.prim_func
def main_split( def main_split(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype), B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype), Ct: T.Tensor((N, M), out_dtype),
): ):
SplitC = T.alloc_buffer([ SplitC = T.alloc_buffer([split, (N + block_N - 1) // block_N * block_N, (M + block_M - 1) // block_M * block_M], out_dtype)
split, (N + block_N - 1) // block_N * block_N, with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split, threads=threads) as (bx, by, bz):
(M + block_M - 1) // block_M * block_M
], out_dtype)
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), split,
threads=threads) as (bx, by, bz):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype) B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
...@@ -164,10 +150,12 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): ...@@ -164,10 +150,12 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
T.annotate_layout({ T.annotate_layout(
B_shared: tilelang.layout.make_swizzled_layout(B_shared), {
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), B_shared: tilelang.layout.make_swizzled_layout(B_shared),
}) Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
}
)
T.clear(Ct_local) T.clear(Ct_local)
for k in T.Pipelined(K // (block_K * split), num_stages=num_stages): for k in T.Pipelined(K // (block_K * split), num_stages=num_stages):
...@@ -183,8 +171,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): ...@@ -183,8 +171,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
) )
T.copy(B_dequantize_local, B_dequantize_prev_local) T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True) T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, SplitC[bz, bx * block_N:(bx + 1) * block_N, T.copy(Ct_local, SplitC[bz, bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M])
by * block_M:(by + 1) * block_M])
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by):
acc = T.alloc_fragment((block_N, block_M), out_dtype) acc = T.alloc_fragment((block_N, block_M), out_dtype)
T.clear(acc) T.clear(acc)
...@@ -195,12 +182,11 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): ...@@ -195,12 +182,11 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype), B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype), Ct: T.Tensor((N, M), out_dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype) B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
...@@ -209,10 +195,12 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): ...@@ -209,10 +195,12 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
T.annotate_layout({ T.annotate_layout(
B_shared: tilelang.layout.make_swizzled_layout(B_shared), {
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), B_shared: tilelang.layout.make_swizzled_layout(B_shared),
}) Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
}
)
T.clear(Ct_local) T.clear(Ct_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages): for k in T.Pipelined(K // block_K, num_stages=num_stages):
...@@ -229,8 +217,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): ...@@ -229,8 +217,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
T.copy(B_dequantize_local, B_dequantize_prev_local) T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True) T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, Ct_shared) T.copy(Ct_local, Ct_shared)
T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N, T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M])
by * block_M:(by + 1) * block_M])
if split == 1: if split == 1:
return main return main
...@@ -241,12 +228,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): ...@@ -241,12 +228,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
@autotune(configs=get_configs(), warmup=10, rep=10) @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[2]) @tilelang.jit(out_idx=[2])
def kernel(block_M=None, def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None, split=None):
block_N=None,
block_K=None,
num_stages=None,
threads=None,
split=None):
return kernel_func(block_M, block_N, block_K, num_stages, threads, split).prim_func return kernel_func(block_M, block_N, block_K, num_stages, threads, split).prim_func
return kernel() return kernel()
...@@ -269,10 +251,10 @@ def ref_program(A, qB): ...@@ -269,10 +251,10 @@ def ref_program(A, qB):
def main(m=256, n=256, k=256, tune=False): def main(m=256, n=256, k=256, tune=False):
total_flops = 2 * m * n * k total_flops = 2 * m * n * k
if (not tune): if not tune:
kernel = matmul( kernel = matmul(m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)(
m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)( block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1
block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1) )
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.") print("All checks pass.")
...@@ -293,10 +275,10 @@ def main(m=256, n=256, k=256, tune=False): ...@@ -293,10 +275,10 @@ def main(m=256, n=256, k=256, tune=False):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--m', type=int, default=256, help='M') parser.add_argument("--m", type=int, default=256, help="M")
parser.add_argument('--n', type=int, default=256, help='N') parser.add_argument("--n", type=int, default=256, help="N")
parser.add_argument('--k', type=int, default=256, help='K') parser.add_argument("--k", type=int, default=256, help="K")
parser.add_argument('--tune', action='store_true', help='tune configs') parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args() args = parser.parse_args()
M, N, K = args.m, args.n, args.k M, N, K = args.m, args.n, args.k
main(M, N, K, args.tune) main(M, N, K, args.tune)
...@@ -42,8 +42,8 @@ def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): ...@@ -42,8 +42,8 @@ def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
@T.prim_func @T.prim_func
def main( def main(
B: T.Tensor(B_shape, storage_dtype), B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((N, K), in_dtype), C: T.Tensor((N, K), in_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx): with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
...@@ -66,13 +66,12 @@ def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128): ...@@ -66,13 +66,12 @@ def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
def torch_convert(tensor): def torch_convert(tensor):
def _convert(val, pos): def _convert(val, pos):
assert val.dtype == torch.uint8 assert val.dtype == torch.uint8
val = val.view(torch.int8) val = val.view(torch.int8)
mask = (1 << 4) - 1 mask = (1 << 4) - 1
i4_shifted = ((val >> (pos * 4)) & mask) i4_shifted = (val >> (pos * 4)) & mask
i4 = ((i4_shifted << 4) >> 4) i4 = (i4_shifted << 4) >> 4
return i4.view(torch.int8) return i4.view(torch.int8)
...@@ -94,7 +93,6 @@ def ref_program(A, qB): ...@@ -94,7 +93,6 @@ def ref_program(A, qB):
def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False): def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
@tilelang.jit(out_idx=[2]) @tilelang.jit(out_idx=[2])
def kernel_func(block_M, block_N, block_K, num_stages, threads): def kernel_func(block_M, block_N, block_K, num_stages, threads):
num_elems_per_byte = 8 // num_bits num_elems_per_byte = 8 // num_bits
...@@ -109,12 +107,11 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune ...@@ -109,12 +107,11 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype), B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype), Ct: T.Tensor((N, M), out_dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype) B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
...@@ -123,10 +120,12 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune ...@@ -123,10 +120,12 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype) Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype) Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
T.annotate_layout({ T.annotate_layout(
B_shared: tilelang.layout.make_swizzled_layout(B_shared), {
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared), B_shared: tilelang.layout.make_swizzled_layout(B_shared),
}) Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
}
)
T.clear(Ct_local) T.clear(Ct_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages): for k in T.Pipelined(K // block_K, num_stages=num_stages):
...@@ -143,8 +142,7 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune ...@@ -143,8 +142,7 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune
T.copy(B_dequantize_local, B_dequantize_prev_local) T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True) T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, Ct_shared) T.copy(Ct_local, Ct_shared)
T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N, T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M])
by * block_M:(by + 1) * block_M])
return main return main
...@@ -167,10 +165,10 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune ...@@ -167,10 +165,10 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune
def main(m=128, n=256, k=256, tune=False): def main(m=128, n=256, k=256, tune=False):
total_flops = 2 * m * n * k total_flops = 2 * m * n * k
if (not tune): if not tune:
kernel = matmul_int8xint4( kernel = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)(
m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)( block_M=32, block_N=32, block_K=128, num_stages=1, threads=128
block_M=32, block_N=32, block_K=128, num_stages=1, threads=128) )
profiler = kernel.get_profiler() profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program, rtol=1e-2, atol=1e-2) profiler.assert_allclose(ref_program, rtol=1e-2, atol=1e-2)
print("All checks pass.") print("All checks pass.")
...@@ -4,7 +4,8 @@ from typing import Optional, Callable, Any ...@@ -4,7 +4,8 @@ from typing import Optional, Callable, Any
import torch import torch
from tilelang import DataType from tilelang import DataType
from tilelang.quantize import ( from tilelang.quantize import (
_tir_packed_int_to_int_convert,) _tir_packed_int_to_int_convert,
)
@tilelang.jit @tilelang.jit
...@@ -26,11 +27,10 @@ def dequantize_gemv( ...@@ -26,11 +27,10 @@ def dequantize_gemv(
group_size: int = -1, group_size: int = -1,
with_scaling: bool = False, with_scaling: bool = False,
) -> Callable[..., Any]: ) -> Callable[..., Any]:
assert n_partition is not None, "n_partition must be provided" assert n_partition is not None, "n_partition must be provided"
assert reduce_thread is not None, ( assert reduce_thread is not None, (
"reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV" "reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMVsch_outer_reduction_with_config is not implemented"
"sch_outer_reduction_with_config is not implemented") )
assert trans_A is False, "Dequantize is only implemented for trans_A=False currently" assert trans_A is False, "Dequantize is only implemented for trans_A=False currently"
assert trans_B is True, "Dequantize is only implemented for trans_B=True currently" assert trans_B is True, "Dequantize is only implemented for trans_B=True currently"
...@@ -81,12 +81,12 @@ def dequantize_gemv( ...@@ -81,12 +81,12 @@ def dequantize_gemv(
C: T.Tensor[C_shape, out_dtype], C: T.Tensor[C_shape, out_dtype],
): ):
with T.Kernel( with T.Kernel(
T.ceildiv(N, n_partition), T.ceildiv(N, n_partition),
M, M,
threads=(reduce_thread, n_partition), threads=(reduce_thread, n_partition),
) as ( ) as (
bx, bx,
by, by,
): ):
A_local = T.alloc_local((micro_size_k,), in_dtype) A_local = T.alloc_local((micro_size_k,), in_dtype)
B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype) B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype)
...@@ -107,8 +107,7 @@ def dequantize_gemv( ...@@ -107,8 +107,7 @@ def dequantize_gemv(
for v in T.vectorized(micro_size_k_compressed): for v in T.vectorized(micro_size_k_compressed):
B_quant_local[v] = B[ B_quant_local[v] = B[
bx * n_partition + ni, bx * n_partition + ni,
ko * (reduce_thread * micro_size_k_compressed) + ko * (reduce_thread * micro_size_k_compressed) + kr * micro_size_k_compressed + v,
kr * micro_size_k_compressed + v,
] ]
if fast_decoding: if fast_decoding:
...@@ -120,10 +119,9 @@ def dequantize_gemv( ...@@ -120,10 +119,9 @@ def dequantize_gemv(
) )
else: else:
for ki in T.serial(micro_size_k): for ki in T.serial(micro_size_k):
B_dequantize_local[ki] = _tir_packed_int_to_int_convert( B_dequantize_local[ki] = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
storage_type, num_bits, B_quant_local[ki // num_elems_per_byte], ki % num_elems_per_byte, in_dtype
storage_nbit)(num_bits, B_quant_local[ki // num_elems_per_byte], )
ki % num_elems_per_byte, in_dtype)
if use_dp4a: if use_dp4a:
for ki in T.serial(micro_size_k // dp4a_size): for ki in T.serial(micro_size_k // dp4a_size):
...@@ -137,9 +135,9 @@ def dequantize_gemv( ...@@ -137,9 +135,9 @@ def dequantize_gemv(
accum_res[0] += A_local[ki] * B_dequantize_local[ki] accum_res[0] += A_local[ki] * B_dequantize_local[ki]
with T.attr( with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope", "reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"), T.reinterpret(T.uint64(0), dtype="handle"),
): ):
T.evaluate( T.evaluate(
T.tvm_thread_allreduce( T.tvm_thread_allreduce(
...@@ -149,7 +147,8 @@ def dequantize_gemv( ...@@ -149,7 +147,8 @@ def dequantize_gemv(
reduced_accum_res[0], reduced_accum_res[0],
kr, kr,
dtype="handle", dtype="handle",
)) )
)
if kr == 0: if kr == 0:
C[by, bx * n_partition + ni] = reduced_accum_res[0] C[by, bx * n_partition + ni] = reduced_accum_res[0]
...@@ -174,26 +173,39 @@ def main() -> None: ...@@ -174,26 +173,39 @@ def main() -> None:
group_size = -1 group_size = -1
with_scaling = False with_scaling = False
kernel = dequantize_gemv(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits, storage_dtype, kernel = dequantize_gemv(
source_format, n_partition, reduce_thread, fast_decoding, trans_A, M,
trans_B, group_size, with_scaling) N,
K,
in_dtype,
out_dtype,
accum_dtype,
num_bits,
storage_dtype,
source_format,
n_partition,
reduce_thread,
fast_decoding,
trans_A,
trans_B,
group_size,
with_scaling,
)
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
num_elems_per_byte = storage_nbit // num_bits num_elems_per_byte = storage_nbit // num_bits
A = torch.rand(M, K, dtype=getattr(torch, in_dtype)).cuda() A = torch.rand(M, K, dtype=getattr(torch, in_dtype)).cuda()
qB = torch.randint( qB = torch.randint(0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda()
0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda()
C = torch.zeros(M, N, dtype=getattr(torch, accum_dtype)).cuda() C = torch.zeros(M, N, dtype=getattr(torch, accum_dtype)).cuda()
if fast_decoding: if fast_decoding:
from tilelang.quantize.utils import interleave_weight from tilelang.quantize.utils import interleave_weight
qB = interleave_weight(qB, num_bits, in_dtype) qB = interleave_weight(qB, num_bits, in_dtype)
kernel(A, qB, C) kernel(A, qB, C)
# int4 reference # int4 reference
B = ( B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device)
torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
dtype=torch.half).to(torch.half).to(A.device))
for j in range(B.shape[1]): for j in range(B.shape[1]):
B[:, j] = ((qB[:, j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half) B[:, j] = ((qB[:, j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
...@@ -25,6 +25,7 @@ def get_configs(): ...@@ -25,6 +25,7 @@ def get_configs():
List[dict]: A list of configuration dictionaries covering all combinations. List[dict]: A list of configuration dictionaries covering all combinations.
""" """
import itertools import itertools
iter_params = dict( iter_params = dict(
block_M=[128], block_M=[128],
block_N=[64, 128, 256], block_N=[64, 128, 256],
...@@ -33,33 +34,33 @@ def get_configs(): ...@@ -33,33 +34,33 @@ def get_configs():
threads=[128, 256, 512], threads=[128, 256, 512],
split=[1], split=[1],
) )
return [{ return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
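The comprehension relies on dict preserving insertion order, so zipping the keys of iter_params with each tuple from itertools.product yields one keyword dictionary per grid point. A toy sketch of the same pattern (the parameter lists here are illustrative, not the full grid above):

import itertools

iter_params = dict(block_M=[128], block_N=[64, 128], threads=[128, 256])
configs = [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
print(configs[0])  # {'block_M': 128, 'block_N': 64, 'threads': 128}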
@tilelang.autotune(configs=get_configs()) @tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[-1]) @tilelang.jit(out_idx=[-1])
def matmul(M, def matmul(
N, M,
K, N,
topk, K,
E, topk,
padding_M, E,
in_dtype, padding_M,
out_dtype, in_dtype,
accum_dtype, out_dtype,
source_format='uint', accum_dtype,
num_bits=4, source_format="uint",
scale_size=32, num_bits=4,
fast_dequant=True, scale_size=32,
with_bias=False, fast_dequant=True,
block_M=128, with_bias=False,
block_N=256, block_M=128,
block_K=128, block_N=256,
num_stages=2, block_K=128,
threads=256, num_stages=2,
split=1): threads=256,
split=1,
):
""" """
Construct and return a grouped (Mixture-of-Experts) matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized, expert-grouped B (shape ExNxQK) and writes an output of shape (M, topk, N) in out_dtype. Construct and return a grouped (Mixture-of-Experts) matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized, expert-grouped B (shape ExNxQK) and writes an output of shape (M, topk, N) in out_dtype.
...@@ -115,11 +116,12 @@ def matmul(M, ...@@ -115,11 +116,12 @@ def matmul(M,
Block_QK = block_K // num_elems_per_byte Block_QK = block_K // num_elems_per_byte
A_shared_shape = (block_M, block_K) A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, Block_QK) B_shared_shape = (block_N, Block_QK)
Bias_shared_shape = (block_N) Bias_shared_shape = block_N
B_dequantize_shared_shape = (block_N, block_K) B_dequantize_shared_shape = (block_N, block_K)
assert K % (block_K * split) == 0 assert K % (block_K * split) == 0
from tilelang.quantize import get_mxfp_intrin_group from tilelang.quantize import get_mxfp_intrin_group
# fast_dequant_bf16_fp4_twiddling # fast_dequant_bf16_fp4_twiddling
mxfp_intrin_info = get_mxfp_intrin_group( mxfp_intrin_info = get_mxfp_intrin_group(
out_dtype=in_dtype, out_dtype=in_dtype,
...@@ -221,19 +223,16 @@ def matmul(M, ...@@ -221,19 +223,16 @@ def matmul(M,
for v in T.vectorized(0, local_size): for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v index = i * threads * local_size + tx * local_size + v
B_dequantize_shared[index // block_K, B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v]
index % block_K] = B_dequantize_local_thread[v]
return fast_dequant_bf16_fp4_twiddling return fast_dequant_bf16_fp4_twiddling
def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"): def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
assert in_dtype in ["fp4"] assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"] assert out_dtype in ["bfloat16"]
@T.macro @T.macro
def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k): def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k):
B_local = T.alloc_fragment(B_shared_shape, storage_dtype) B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype) B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype)
...@@ -244,8 +243,8 @@ def matmul(M, ...@@ -244,8 +243,8 @@ def matmul(M,
B_local[i, j // num_elems_per_byte], B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte, j % num_elems_per_byte,
Scale_shared[ Scale_shared[
i, k * block_K // scale_size + j // i, k * block_K // scale_size + j // scale_size
scale_size],  # Scale is the exponent part, stored as uint8 ],  # Scale is the exponent part, stored as uint8
dtype=out_dtype, dtype=out_dtype,
) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size])) ) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size]))
T.copy(B_dequantize_local, B_dequantize_shared) T.copy(B_dequantize_local, B_dequantize_shared)
...@@ -254,19 +253,17 @@ def matmul(M, ...@@ -254,19 +253,17 @@ def matmul(M,
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, K), in_dtype), A: T.Tensor((M, K), in_dtype),
B: T.Tensor((E, N, QK), storage_dtype), B: T.Tensor((E, N, QK), storage_dtype),
Scale: T.Tensor((E, N, K // scale_size), storage_dtype), Scale: T.Tensor((E, N, K // scale_size), storage_dtype),
Bias: T.Tensor((E, N), out_dtype), Bias: T.Tensor((E, N), out_dtype),
# Add fusedmoe tensors # Add fusedmoe tensors
topk_weights: T.Tensor((M * topk), out_dtype), topk_weights: T.Tensor((M * topk), out_dtype),
sorted_token_ids: T.Tensor((padding_M), "int32"), sorted_token_ids: T.Tensor((padding_M), "int32"),
expert_ids: T.Tensor((padding_M // block_M), "int32"), expert_ids: T.Tensor((padding_M // block_M), "int32"),
C: T.Tensor((M, topk, N), out_dtype), C: T.Tensor((M, topk, N), out_dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(padding_M, block_M), threads=threads) as (bx, by):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(padding_M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype) A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype) B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype) B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
...@@ -280,17 +277,19 @@ def matmul(M, ...@@ -280,17 +277,19 @@ def matmul(M,
# May use much more shared memory than necessary # May use much more shared memory than necessary
Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype) Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype)
T.annotate_layout({ T.annotate_layout(
A_shared: tilelang.layout.make_swizzled_layout(A_shared), {
B_shared: tilelang.layout.make_swizzled_layout(B_shared), A_shared: tilelang.layout.make_swizzled_layout(A_shared),
C_shared: tilelang.layout.make_swizzled_layout(C_shared), B_shared: tilelang.layout.make_swizzled_layout(B_shared),
}) C_shared: tilelang.layout.make_swizzled_layout(C_shared),
}
)
T.use_swizzle(10) T.use_swizzle(10)
if threads == 512: if threads == 512:
T.disable_warp_group_reg_alloc() T.disable_warp_group_reg_alloc()
T.copy(sorted_token_ids[by * block_M:(by + 1) * block_M], sorted_token_ids_shared) T.copy(sorted_token_ids[by * block_M : (by + 1) * block_M], sorted_token_ids_shared)
expert_id[0] = expert_ids[by] expert_id[0] = expert_ids[by]
# Get the topk weights of each token in the current block # Get the topk weights of each token in the current block
...@@ -300,11 +299,11 @@ def matmul(M, ...@@ -300,11 +299,11 @@ def matmul(M,
# Get bias and scale based on the expert id # Get bias and scale based on the expert id
if with_bias: if with_bias:
T.copy(Bias[expert_id[0], bx * block_N:(bx + 1) * block_N], Bias_shared) T.copy(Bias[expert_id[0], bx * block_N : (bx + 1) * block_N], Bias_shared)
else: else:
T.clear(Bias_shared) T.clear(Bias_shared)
T.copy(Scale[expert_id[0], bx * block_N:(bx + 1) * block_N, :], Scale_shared) T.copy(Scale[expert_id[0], bx * block_N : (bx + 1) * block_N, :], Scale_shared)
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
C_local[i, j] = Bias_shared[j] C_local[i, j] = Bias_shared[j]
...@@ -317,14 +316,13 @@ def matmul(M, ...@@ -317,14 +316,13 @@ def matmul(M,
base = copy_i * threads * 16 + tx * 16 base = copy_i * threads * 16 + tx * 16
if sorted_token_ids_shared[base // block_K] != -1: if sorted_token_ids_shared[base // block_K] != -1:
for copy_j in T.vectorized(16): for copy_j in T.vectorized(16):
A_shared[base // block_K, base % block_K + A_shared[base // block_K, base % block_K + copy_j] = A[
copy_j] = A[sorted_token_ids_shared[base // block_K] // topk, sorted_token_ids_shared[base // block_K] // topk, k * block_K + base % block_K + copy_j
k * block_K + base % block_K + copy_j] ]
T.copy(B[expert_id[0], bx * block_N, k * block_K // num_elems_per_byte], B_shared) T.copy(B[expert_id[0], bx * block_N, k * block_K // num_elems_per_byte], B_shared)
if fast_dequant: if fast_dequant:
get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, k)
k)
else: else:
get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k) get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k)
...@@ -338,10 +336,11 @@ def matmul(M, ...@@ -338,10 +336,11 @@ def matmul(M,
base = copy_i * threads * 16 + tx * 16 base = copy_i * threads * 16 + tx * 16
if sorted_token_ids_shared[base // block_N] != -1: if sorted_token_ids_shared[base // block_N] != -1:
for copy_j in T.vectorized(16): for copy_j in T.vectorized(16):
C[sorted_token_ids_shared[base // block_N] // topk, C[
sorted_token_ids_shared[base // block_N] % topk, bx * block_N + sorted_token_ids_shared[base // block_N] // topk,
base % block_N + copy_j] = C_shared[base // block_N, sorted_token_ids_shared[base // block_N] % topk,
base % block_N + copy_j] bx * block_N + base % block_N + copy_j,
] = C_shared[base // block_N, base % block_N + copy_j]
return main return main
...@@ -355,7 +354,7 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc ...@@ -355,7 +354,7 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc
assert scale_size == 32 # MXFP4 assert scale_size == 32 # MXFP4
# Initialize output tensor # Initialize output tensor
C = torch.ones((M, topk, N), dtype=getattr(torch, dtypeC), device='cuda') C = torch.ones((M, topk, N), dtype=getattr(torch, dtypeC), device="cuda")
# Iterate over sorted_token_ids # Iterate over sorted_token_ids
for idx in range(len(sorted_token_ids)): # padding_M for idx in range(len(sorted_token_ids)): # padding_M
...@@ -370,14 +369,11 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc ...@@ -370,14 +369,11 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc
# Dequantize the expert weights # Dequantize the expert weights
B = torch_convert_bit_twiddling(qB[expert_id]) # shape: (N, K) B = torch_convert_bit_twiddling(qB[expert_id]) # shape: (N, K)
B *= 2**( B *= 2 ** (Scale[expert_id][:, (torch.arange(B.shape[1], device=B.device) // scale_size)].to(torch.bfloat16))
Scale[expert_id][:, (torch.arange(B.shape[1], device=B.device) // scale_size)].to(
torch.bfloat16))
# Compute the output for this token-expert pair # Compute the output for this token-expert pair
# token_embedding @ B.T + bias # token_embedding @ B.T + bias
output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to( output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to(torch.bfloat16)) + Bias[expert_id]
torch.bfloat16)) + Bias[expert_id]
output = output.to(torch.__getattribute__(dtypeC)) output = output.to(torch.__getattribute__(dtypeC))
# Apply the topk weight # Apply the topk weight
...@@ -391,14 +387,12 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc ...@@ -391,14 +387,12 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc
def get_data(m, n, k, qk, scale_size, topk, E, block_M): def get_data(m, n, k, qk, scale_size, topk, E, block_M):
A = torch.empty(m, k, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1) A = torch.empty(m, k, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1)
qB = torch.randint( qB = torch.randint(0, 256, (E, n, qk), dtype=torch.uint8, device="cuda") # Quantized weight tensor for E experts.
0, 256, (E, n, qk), dtype=torch.uint8, Scale = torch.randint(0, 8, (E, n, k // scale_size), dtype=torch.uint8, device="cuda")
device='cuda') # Quantized weight tensor for E experts. Bias = torch.empty(E, n, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1)
Scale = torch.randint(0, 8, (E, n, k // scale_size), dtype=torch.uint8, device='cuda')
Bias = torch.empty(E, n, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1) weights = torch.empty(m, E, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1)
weights = torch.empty(m, E, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1)
# topk_weights: Router weights for the top-k experts for each token. # topk_weights: Router weights for the top-k experts for each token.
# Shape: (m, topk) # Shape: (m, topk)
# tokens_experts: A flattened tensor of expert assignments for each token. # tokens_experts: A flattened tensor of expert assignments for each token.
...@@ -420,10 +414,7 @@ def get_data(m, n, k, qk, scale_size, topk, E, block_M): ...@@ -420,10 +414,7 @@ def get_data(m, n, k, qk, scale_size, topk, E, block_M):
pad_len = ((cnt + block_M - 1) // block_M) * block_M - cnt pad_len = ((cnt + block_M - 1) // block_M) * block_M - cnt
if pad_len > 0: if pad_len > 0:
# -1 for padding (`M` instead in vLLM moe_align_block_size()) # -1 for padding (`M` instead in vLLM moe_align_block_size())
group_token_ids = torch.cat([ group_token_ids = torch.cat([group_token_ids, torch.full((pad_len,), -1, dtype=group_token_ids.dtype, device="cuda")])
group_token_ids,
torch.full((pad_len,), -1, dtype=group_token_ids.dtype, device='cuda')
])
padded_token_ids.append(group_token_ids) padded_token_ids.append(group_token_ids)
expert_ids.extend([eid] * ((cnt + block_M - 1) // block_M)) expert_ids.extend([eid] * ((cnt + block_M - 1) // block_M))
start = end start = end
...@@ -431,21 +422,13 @@ def get_data(m, n, k, qk, scale_size, topk, E, block_M): ...@@ -431,21 +422,13 @@ def get_data(m, n, k, qk, scale_size, topk, E, block_M):
# sorted_token_ids: The final flattened and padded tensor of token indices. # sorted_token_ids: The final flattened and padded tensor of token indices.
sorted_token_ids = torch.cat(padded_token_ids, dim=0).to(torch.int32) # (padding_M,) sorted_token_ids = torch.cat(padded_token_ids, dim=0).to(torch.int32) # (padding_M,)
# expert_ids: The final tensor of expert IDs corresponding to `sorted_token_ids`. # expert_ids: The final tensor of expert IDs corresponding to `sorted_token_ids`.
expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device='cuda') # (padding_M,) expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device="cuda") # (padding_M,)
padding_M = sorted_token_ids.shape[0] # padding_M: token number after padding padding_M = sorted_token_ids.shape[0] # padding_M: token number after padding
return A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M return A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M
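A tiny worked example of the block alignment performed by get_data() above, using made-up assignments: each expert's token indices are padded with -1 to a multiple of block_M, and expert_ids keeps one expert id per block of sorted_token_ids:

import torch

block_M = 4
tokens_experts = torch.tensor([0, 1, 1, 0, 2, 2])  # flattened (token, expert) assignments
padded, expert_ids = [], []
for eid in tokens_experts.unique(sorted=True):
    group = (tokens_experts == eid).nonzero(as_tuple=True)[0]
    pad = (-len(group)) % block_M
    group = torch.cat([group, torch.full((pad,), -1, dtype=group.dtype)])  # -1 marks padding
    padded.append(group)
    expert_ids.extend([eid.item()] * (len(group) // block_M))
sorted_token_ids = torch.cat(padded)
print(sorted_token_ids.tolist())  # [0, 3, -1, -1, 1, 2, -1, -1, 4, 5, -1, -1]
print(expert_ids)                 # [0, 1, 2]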
def main(m=256, def main(m=256, n=256, k=256, scale_size=32, topk=4, E=32, fast_dequant=True, with_bias=False, tune=False):
n=256,
k=256,
scale_size=32,
topk=4,
E=32,
fast_dequant=True,
with_bias=False,
tune=False):
# Tunable parameters # Tunable parameters
block_M, block_N, block_K = 128, 256, 128 # noqa: F841 block_M, block_N, block_K = 128, 256, 128 # noqa: F841
num_stages = 1 # noqa: F841 num_stages = 1 # noqa: F841
...@@ -456,8 +439,7 @@ def main(m=256, ...@@ -456,8 +439,7 @@ def main(m=256,
num_bits = 4 num_bits = 4
num_elems_per_byte = 8 // num_bits num_elems_per_byte = 8 // num_bits
qk = k // num_elems_per_byte qk = k // num_elems_per_byte
A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data( A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data(m, n, k, qk, scale_size, topk, E, block_M)
m, n, k, qk, scale_size, topk, E, block_M)
if tune: if tune:
with set_autotune_inputs([A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids]): with set_autotune_inputs([A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids]):
...@@ -510,14 +492,11 @@ def main(m=256, ...@@ -510,14 +492,11 @@ def main(m=256,
expert_ids, expert_ids,
) )
print('Tilelang kernel run finished.') print("Tilelang kernel run finished.")
ref_output = ref_moe( ref_output = ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=block_M) # Maybe a little bit slow...
A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids,
block_M=block_M) # Maybe a little bit slow...
latency = tilelang.profiler.do_bench( latency = tilelang.profiler.do_bench(lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=100)
lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=100)
print("Tilelang: {:.2f} ms".format(latency)) print("Tilelang: {:.2f} ms".format(latency))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
...@@ -525,32 +504,19 @@ def main(m=256, ...@@ -525,32 +504,19 @@ def main(m=256,
max_val = diff.max() max_val = diff.max()
max_idx = diff.argmax() max_idx = diff.argmax()
print(f"max abs diff: {max_val} at index: {max_idx}") print(f"max abs diff: {max_val} at index: {max_idx}")
assert_similar( assert_similar(output, ref_output, name="output", eps=2e-5) # We care about the similarity rather than abs. difference
output, ref_output, name="output",
eps=2e-5) # We care about the similarity rather than abs. difference
print("All checks pass. ✅") print("All checks pass. ✅")
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument("--M", type=int, default=16384, help="M") # From gpt-oss-20b MoE's first gemm
"--M", type=int, default=16384, help="M") # From gpt-oss-20b MoE's first gemm
parser.add_argument("--N", type=int, default=5760, help="N") parser.add_argument("--N", type=int, default=5760, help="N")
parser.add_argument("--K", type=int, default=2944, help="K") parser.add_argument("--K", type=int, default=2944, help="K")
parser.add_argument("--scale_size", type=int, default=32, help="scale size") parser.add_argument("--scale_size", type=int, default=32, help="scale size")
parser.add_argument( parser.add_argument("--topk", type=int, default=4, help="topk") # experts activated for each token
"--topk", type=int, default=4, help="topk") # experts activated for each token
parser.add_argument("--E", type=int, default=32, help="E") # number of experts parser.add_argument("--E", type=int, default=32, help="E") # number of experts
parser.add_argument("--tune", action="store_true", help="tune configs") parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args() args = parser.parse_args()
main( main(args.M, args.N, args.K, args.scale_size, topk=args.topk, E=args.E, fast_dequant=True, with_bias=True, tune=args.tune)
args.M,
args.N,
args.K,
args.scale_size,
topk=args.topk,
E=args.E,
fast_dequant=True,
with_bias=True,
tune=args.tune)
...@@ -11,7 +11,6 @@ from utils import get_abs_err, get_err_ratio ...@@ -11,7 +11,6 @@ from utils import get_abs_err, get_err_ratio
class RegsiterLossFunction(torch.autograd.Function): class RegsiterLossFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, x, loss): def forward(ctx, x, loss):
ctx.save_for_backward(loss) ctx.save_for_backward(loss)
...@@ -38,49 +37,43 @@ def ref_deepseek_sparse_attention_innner( ...@@ -38,49 +37,43 @@ def ref_deepseek_sparse_attention_innner(
index_sm_scale: Optional[float] = None, index_sm_scale: Optional[float] = None,
): ):
dtype = q.dtype dtype = q.dtype
q, kv, index_q, index_k, weights = map(lambda x: x.to(torch.float32), q, kv, index_q, index_k, weights = map(lambda x: x.to(torch.float32), (q, kv, index_q, index_k, weights))
(q, kv, index_q, index_k, weights))
index_sm_scale = index_q.shape[-1]**-0.5 index_sm_scale = index_q.shape[-1] ** -0.5
b, s = index_q.shape[:2] b, s = index_q.shape[:2]
# tl_topk_indices = tl_topk_indices.to(torch.int64) # tl_topk_indices = tl_topk_indices.to(torch.int64)
# tl_topk_indices[tl_topk_indices == -1] = s # tl_topk_indices[tl_topk_indices == -1] = s
casual_mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device) casual_mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device)
index_logits = einsum(index_q, index_k, 'b s1 h k, b s2 k -> b s1 h s2') index_logits = einsum(index_q, index_k, "b s1 h k, b s2 k -> b s1 h s2")
index_logits = F.relu(index_logits) index_logits = F.relu(index_logits)
index_logits = (index_logits * weights.unsqueeze(-1)).sum( index_logits = (index_logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) * index_sm_scale
dim=-2, dtype=torch.float32) * index_sm_scale index_logits = torch.where(casual_mask, index_logits, float("-inf"))
index_logits = torch.where(casual_mask, index_logits, float('-inf'))
topk_indices = torch.topk(index_logits, k=topk, dim=-1).indices topk_indices = torch.topk(index_logits, k=topk, dim=-1).indices
topk_logits = torch.gather( topk_logits = torch.gather(F.pad(index_logits, (0, 1), value=float("-inf")), dim=-1, index=topk_indices)
F.pad(index_logits, (0, 1), value=float('-inf')), dim=-1, index=topk_indices)
topk_score = F.log_softmax(topk_logits, dim=-1, dtype=torch.float32) topk_score = F.log_softmax(topk_logits, dim=-1, dtype=torch.float32)
index_topk_score = topk_score index_topk_score = topk_score
if sm_scale is None: if sm_scale is None:
sm_scale = kv.shape[-1]**-0.5 sm_scale = kv.shape[-1] ** -0.5
h = q.shape[-2] h = q.shape[-2]
index_mask = torch.zeros((b, s, s + 1), dtype=torch.bool, device="cuda")\ index_mask = torch.zeros((b, s, s + 1), dtype=torch.bool, device="cuda").scatter_(
.scatter_(dim=-1, index=topk_indices, src=torch.ones_like(topk_indices, dtype=torch.bool))[:, :, :-1] dim=-1, index=topk_indices, src=torch.ones_like(topk_indices, dtype=torch.bool)
mask = repeat(casual_mask & index_mask, 'b s1 s2 -> b s1 h s2', h=h) )[:, :, :-1]
mask = repeat(casual_mask & index_mask, "b s1 s2 -> b s1 h s2", h=h)
k, v = kv, kv[..., :dim_v] k, v = kv, kv[..., :dim_v]
logits = einsum(q, k, 'b s1 h d, b s2 d -> b s1 h s2') * sm_scale logits = einsum(q, k, "b s1 h d, b s2 d -> b s1 h s2") * sm_scale
logits = torch.where(mask, logits, float('-inf')) logits = torch.where(mask, logits, float("-inf"))
attn_score = F.softmax(logits, dim=-1, dtype=torch.float32) attn_score = F.softmax(logits, dim=-1, dtype=torch.float32)
o = einsum(attn_score, v, 'b s1 h s2, b s2 d -> b s1 h d') o = einsum(attn_score, v, "b s1 h s2, b s2 d -> b s1 h d")
attn_score = attn_score.sum(dim=-2) # [b, s1, s2] attn_score = attn_score.sum(dim=-2) # [b, s1, s2]
attn_topk_score = torch.gather(F.pad(attn_score, (0, 1)), dim=-1, index=topk_indices) attn_topk_score = torch.gather(F.pad(attn_score, (0, 1)), dim=-1, index=topk_indices)
attn_topk_score = attn_topk_score / attn_topk_score.sum(dim=-1, keepdim=True) attn_topk_score = attn_topk_score / attn_topk_score.sum(dim=-1, keepdim=True)
loss = F.kl_div( loss = F.kl_div(index_topk_score.clip(-100, 0), attn_topk_score.detach().log().clip(-100, 0), log_target=True, reduction="sum")
index_topk_score.clip(-100, 0),
attn_topk_score.detach().log().clip(-100, 0),
log_target=True,
reduction="sum")
o = register_loss(o, loss) o = register_loss(o, loss)
return o.to(dtype), topk_indices return o.to(dtype), topk_indices
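The distillation loss above compares two top-k distributions in log space. With log_target=True, F.kl_div expects both arguments as log-probabilities and computes sum(exp(target) * (target - input)) under reduction="sum"; a small self-contained check with made-up logits:

import torch
import torch.nn.functional as F

log_index = F.log_softmax(torch.tensor([2.0, 1.0, 0.5, -1.0]), dim=-1)  # predicted top-k log-probs
log_attn = F.log_softmax(torch.tensor([1.5, 1.2, 0.1, -0.5]), dim=-1)   # target top-k log-probs
loss = F.kl_div(log_index.clip(-100, 0), log_attn.clip(-100, 0), log_target=True, reduction="sum")
manual = (log_attn.exp() * (log_attn - log_index)).sum()
print(torch.allclose(loss, manual))  # True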
...@@ -101,11 +94,11 @@ def ref_deepseek_sparse_attention( ...@@ -101,11 +94,11 @@ def ref_deepseek_sparse_attention(
all_o, all_topk_indices = [], [] all_o, all_topk_indices = [], []
for i in range(offsets.shape[0] - 1): for i in range(offsets.shape[0] - 1):
o, topk_indices = ref_deepseek_sparse_attention_innner( o, topk_indices = ref_deepseek_sparse_attention_innner(
q[None, offsets[i]:offsets[i + 1]], q[None, offsets[i] : offsets[i + 1]],
kv[None, offsets[i]:offsets[i + 1]], kv[None, offsets[i] : offsets[i + 1]],
index_q[None, offsets[i]:offsets[i + 1]], index_q[None, offsets[i] : offsets[i + 1]],
index_k[None, offsets[i]:offsets[i + 1]], index_k[None, offsets[i] : offsets[i + 1]],
weights[None, offsets[i]:offsets[i + 1]], weights[None, offsets[i] : offsets[i + 1]],
topk, topk,
dim_v, dim_v,
sm_scale, sm_scale,
...@@ -119,7 +112,6 @@ def ref_deepseek_sparse_attention( ...@@ -119,7 +112,6 @@ def ref_deepseek_sparse_attention(
class DSAFunction(torch.autograd.Function): class DSAFunction(torch.autograd.Function):
@staticmethod @staticmethod
def forward( def forward(
ctx, ctx,
...@@ -134,12 +126,9 @@ class DSAFunction(torch.autograd.Function): ...@@ -134,12 +126,9 @@ class DSAFunction(torch.autograd.Function):
sm_scale: Optional[float] = None, sm_scale: Optional[float] = None,
): ):
# topk_indices, index_score = ref_index_score(index_q, weights, index_k, topk) # topk_indices, index_score = ref_index_score(index_q, weights, index_k, topk)
topk_indices, index_score = indexer_topk_reducesum_interface(index_q, weights, index_k, topk_indices, index_score = indexer_topk_reducesum_interface(index_q, weights, index_k, topk, offsets)
topk, offsets) o, lse = sparse_mla_fwd_interface(q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), offsets, sm_scale=sm_scale, d_v=dim_v)
o, lse = sparse_mla_fwd_interface( ctx.save_for_backward(q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets)
q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), offsets, sm_scale=sm_scale, d_v=dim_v)
ctx.save_for_backward(q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse,
offsets)
ctx.topk = topk ctx.topk = topk
ctx.dim_v = dim_v ctx.dim_v = dim_v
ctx.sm_scale = sm_scale ctx.sm_scale = sm_scale
...@@ -153,19 +142,10 @@ class DSAFunction(torch.autograd.Function): ...@@ -153,19 +142,10 @@ class DSAFunction(torch.autograd.Function):
): ):
q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets = ctx.saved_tensors q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets = ctx.saved_tensors
attn_score = sparse_mla_topk_reducesum_interface( attn_score = sparse_mla_topk_reducesum_interface(
q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), lse, offsets, q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), lse, offsets, dim_v=ctx.dim_v
dim_v=ctx.dim_v).squeeze(-2) ).squeeze(-2)
dq, dkv = sparse_mla_bwd( dq, dkv = sparse_mla_bwd(q, kv.unsqueeze(-2), o, do, topk_indices.unsqueeze(-2), lse, offsets, sm_scale=ctx.sm_scale)
q, dindex_q, dweights, dindex_k = indexer_bwd_interface(index_q, weights, index_k, attn_score, index_score, topk_indices, offsets)
kv.unsqueeze(-2),
o,
do,
topk_indices.unsqueeze(-2),
lse,
offsets,
sm_scale=ctx.sm_scale)
dindex_q, dweights, dindex_k = indexer_bwd_interface(index_q, weights, index_k, attn_score,
index_score, topk_indices, offsets)
return dq, dkv.squeeze(-2), dindex_q, dindex_k, dweights, None, None, None, None return dq, dkv.squeeze(-2), dindex_q, dindex_k, dweights, None, None, None, None
...@@ -209,8 +189,7 @@ def test_kernel( ...@@ -209,8 +189,7 @@ def test_kernel(
index_k_grad, index_k.grad = index_k.grad, None index_k_grad, index_k.grad = index_k.grad, None
weights_grad, weights.grad = weights.grad, None weights_grad, weights.grad = weights.grad, None
ref_o, ref_topk_indices = ref_deepseek_sparse_attention(q, kv, index_q, index_k, weights, ref_o, ref_topk_indices = ref_deepseek_sparse_attention(q, kv, index_q, index_k, weights, offsets, topk, D)
offsets, topk, D)
ref_o.backward(do) ref_o.backward(do)
ref_q_grad, q.grad = q.grad, None ref_q_grad, q.grad = q.grad, None
ref_kv_grad, kv.grad = kv.grad, None ref_kv_grad, kv.grad = kv.grad, None
...@@ -219,28 +198,20 @@ def test_kernel( ...@@ -219,28 +198,20 @@ def test_kernel(
ref_weights_grad, weights.grad = weights.grad, None ref_weights_grad, weights.grad = weights.grad, None
print(f"o err: {get_abs_err(o, ref_o):.6f} ratio: {get_err_ratio(o, ref_o):.6f}") print(f"o err: {get_abs_err(o, ref_o):.6f} ratio: {get_err_ratio(o, ref_o):.6f}")
print( print(f"q.grad err: {get_abs_err(q_grad, ref_q_grad):.6f} ratio: {get_err_ratio(q_grad, ref_q_grad):.6f}")
f"q.grad err: {get_abs_err(q_grad, ref_q_grad):.6f} ratio: {get_err_ratio(q_grad, ref_q_grad):.6f}" print(f"kv.grad err: {get_abs_err(kv_grad, ref_kv_grad):.6f} ratio: {get_err_ratio(kv_grad, ref_kv_grad):.6f}")
)
print(
f"kv.grad err: {get_abs_err(kv_grad, ref_kv_grad):.6f} ratio: {get_err_ratio(kv_grad, ref_kv_grad):.6f}"
)
print( print(
f"index_q.grad err: {get_abs_err(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f} ratio: {get_err_ratio(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f}" f"index_q.grad err: {get_abs_err(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f} ratio: {get_err_ratio(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f}"
) )
print( print(f"index_k.grad err: {get_abs_err(index_k_grad, ref_index_k_grad):.6f} ratio: {get_err_ratio(index_k_grad, ref_index_k_grad):.6f}")
f"index_k.grad err: {get_abs_err(index_k_grad, ref_index_k_grad):.6f} ratio: {get_err_ratio(index_k_grad, ref_index_k_grad):.6f}" print(f"weights.grad err: {get_abs_err(weights_grad, ref_weights_grad):.6f} ratio: {get_err_ratio(weights_grad, ref_weights_grad):.6f}")
)
print(
f"weights.grad err: {get_abs_err(weights_grad, ref_weights_grad):.6f} ratio: {get_err_ratio(weights_grad, ref_weights_grad):.6f}"
)
intersections = [] intersections = []
for j in range(S): for j in range(S):
ref_np = ref_topk_indices[j].cpu().to(torch.int32).numpy() ref_np = ref_topk_indices[j].cpu().to(torch.int32).numpy()
trt_np = topk_indices[j].cpu().to(torch.int32).numpy() trt_np = topk_indices[j].cpu().to(torch.int32).numpy()
mask = (trt_np != -1) mask = trt_np != -1
set_ref = set(ref_np[mask]) set_ref = set(ref_np[mask])
set_trt = set(trt_np[mask]) set_trt = set(trt_np[mask])
......
...@@ -5,7 +5,9 @@ import functools ...@@ -5,7 +5,9 @@ import functools
from typing import Callable, Any from typing import Callable, Any
def tensor_cache(fn: Callable[..., torch.Tensor],) -> Callable[..., torch.Tensor]: def tensor_cache(
fn: Callable[..., torch.Tensor],
) -> Callable[..., torch.Tensor]:
""" """
A decorator that caches the most recent result of a function with tensor inputs. A decorator that caches the most recent result of a function with tensor inputs.
...@@ -29,10 +31,12 @@ def tensor_cache(fn: Callable[..., torch.Tensor],) -> Callable[..., torch.Tensor ...@@ -29,10 +31,12 @@ def tensor_cache(fn: Callable[..., torch.Tensor],) -> Callable[..., torch.Tensor
def wrapper(*args: Any, **kwargs: Any) -> Any: def wrapper(*args: Any, **kwargs: Any) -> Any:
nonlocal last_args, last_kwargs, last_result nonlocal last_args, last_kwargs, last_result
if (last_args is not None and last_kwargs is not None) and \ if (
(len(args) == len(last_args) and len(kwargs) == len(last_kwargs)) and \ (last_args is not None and last_kwargs is not None)
all(a is b for a, b in zip(args, last_args, strict=False)) and \ and (len(args) == len(last_args) and len(kwargs) == len(last_kwargs))
all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()): and all(a is b for a, b in zip(args, last_args, strict=False))
and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items())
):
return last_result return last_result
result = fn(*args, **kwargs) result = fn(*args, **kwargs)
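Hypothetical usage of the decorator above; it compares arguments by object identity rather than by value, and this sketch assumes the elided tail of wrapper() stores the latest args, kwargs, and result in the nonlocal cache:

import torch

@tensor_cache
def square(x: torch.Tensor) -> torch.Tensor:
    print("computed")
    return x * x

a = torch.arange(4)
square(a)           # prints "computed"
square(a)           # same tensor object: served from the cache, nothing printed
square(a.clone())   # equal values but a different object, so it recomputes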
...@@ -56,16 +60,15 @@ def prepare_cu_seqlens_from_lens( ...@@ -56,16 +60,15 @@ def prepare_cu_seqlens_from_lens(
@tensor_cache @tensor_cache
def prepare_lens_from_cu_seqlens(cu_seqlens: torch.LongTensor,) -> torch.LongTensor: def prepare_lens_from_cu_seqlens(
cu_seqlens: torch.LongTensor,
) -> torch.LongTensor:
return torch.diff(cu_seqlens) return torch.diff(cu_seqlens)
@tensor_cache @tensor_cache
def prepare_position_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor: def prepare_position_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
return torch.cat([ return torch.cat([torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) for n in prepare_lens(cu_seqlens).unbind()])
torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device)
for n in prepare_lens(cu_seqlens).unbind()
])
@tensor_cache @tensor_cache
......
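How the cu_seqlens helpers above fit together for a concrete packed batch of two sequences with (made-up) lengths 3 and 2:

import torch

cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.long)
lens = torch.diff(cu_seqlens)                                   # prepare_lens_from_cu_seqlens -> tensor([3, 2])
position_ids = torch.cat([torch.arange(int(n)) for n in lens])  # prepare_position_ids -> per-sequence positions
print(position_ids.tolist())  # [0, 1, 2, 0, 1]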
...@@ -49,17 +49,17 @@ def tl_indexer_bwd_impl( ...@@ -49,17 +49,17 @@ def tl_indexer_bwd_impl(
@T.prim_func @T.prim_func
def tl_indexer_bwd_kernel( def tl_indexer_bwd_kernel(
IndexQ: T.Tensor(index_q_shape, dtype), IndexQ: T.Tensor(index_q_shape, dtype),
Weights: T.Tensor(weights_shape, dtype), Weights: T.Tensor(weights_shape, dtype),
IndexK: T.Tensor(index_k_shape, dtype), IndexK: T.Tensor(index_k_shape, dtype),
dIndexQ: T.Tensor(index_q_shape, dtype), dIndexQ: T.Tensor(index_q_shape, dtype),
dWeights: T.Tensor(weights_shape, dtype), dWeights: T.Tensor(weights_shape, dtype),
dIndexK: T.Tensor(index_k_shape, dtype), dIndexK: T.Tensor(index_k_shape, dtype),
AttnScore: T.Tensor(shape_p, FP32), AttnScore: T.Tensor(shape_p, FP32),
IndexScore: T.Tensor(shape_p, FP32), IndexScore: T.Tensor(shape_p, FP32),
TopkIndices: T.Tensor(topk_indices_shape, INT32), TopkIndices: T.Tensor(topk_indices_shape, INT32),
Offsets: T.Tensor(offsets_shape, INT32), Offsets: T.Tensor(offsets_shape, INT32),
TokenIndices: T.Tensor(token_indices_shape, INT32), TokenIndices: T.Tensor(token_indices_shape, INT32),
): ):
with T.Kernel(seq_len, threads=num_threads) as (bx): with T.Kernel(seq_len, threads=num_threads) as (bx):
i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1] i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1]
...@@ -81,7 +81,6 @@ def tl_indexer_bwd_impl( ...@@ -81,7 +81,6 @@ def tl_indexer_bwd_impl(
index_q_shared[i, j] = index_q_shared[i, j] * sm_scale index_q_shared[i, j] = index_q_shared[i, j] * sm_scale
for bi_i in T.Pipelined(num_blocks, num_stages=num_stages): for bi_i in T.Pipelined(num_blocks, num_stages=num_stages):
i_st = bi_i * block_I i_st = bi_i * block_I
i_ed = (bi_i + 1) * block_I i_ed = (bi_i + 1) * block_I
...@@ -91,8 +90,7 @@ def tl_indexer_bwd_impl( ...@@ -91,8 +90,7 @@ def tl_indexer_bwd_impl(
index_k_shared = T.alloc_shared([block_I, dim], dtype=dtype) index_k_shared = T.alloc_shared([block_I, dim], dtype=dtype)
for i, j in T.Parallel(block_I, dim): for i, j in T.Parallel(block_I, dim):
pos = indices_shared[i] pos = indices_shared[i]
index_k_shared[i, j] = T.if_then_else((pos > -1) & (pos <= i_t), index_k_shared[i, j] = T.if_then_else((pos > -1) & (pos <= i_t), IndexK[bos + pos, j], 0)
IndexK[bos + pos, j], 0)
attn_score_shared = T.alloc_shared([block_I], dtype=accum_dtype) attn_score_shared = T.alloc_shared([block_I], dtype=accum_dtype)
index_score_shared = T.alloc_shared([block_I], dtype=accum_dtype) index_score_shared = T.alloc_shared([block_I], dtype=accum_dtype)
...@@ -115,8 +113,7 @@ def tl_indexer_bwd_impl( ...@@ -115,8 +113,7 @@ def tl_indexer_bwd_impl(
# dw # dw
d_weights_i = T.alloc_fragment((block_I, heads), accum_dtype) d_weights_i = T.alloc_fragment((block_I, heads), accum_dtype)
for i, j in T.Parallel(block_I, heads): for i, j in T.Parallel(block_I, heads):
d_weights_i[i, d_weights_i[i, j] = (index_score_shared[i] - attn_score_shared[i]) * logits[i, j]
j] = (index_score_shared[i] - attn_score_shared[i]) * logits[i, j]
T.reduce_sum(d_weights_i, d_weights_frag, dim=0, clear=False) T.reduce_sum(d_weights_i, d_weights_frag, dim=0, clear=False)
d_logits_qk = T.alloc_shared((block_I, heads), accum_dtype) d_logits_qk = T.alloc_shared((block_I, heads), accum_dtype)
...@@ -129,8 +126,7 @@ def tl_indexer_bwd_impl( ...@@ -129,8 +126,7 @@ def tl_indexer_bwd_impl(
d_relu = 1.0 d_relu = 1.0
else: else:
d_relu = 0.0 d_relu = 0.0
d_logits_qk[i, j] = (index_score_shared[i] - d_logits_qk[i, j] = (index_score_shared[i] - attn_score_shared[i]) * d_relu * weights_shared[j]
attn_score_shared[i]) * d_relu * weights_shared[j]
# dq # dq
T.copy(d_logits_qk, d_logits_qk_cast1) T.copy(d_logits_qk, d_logits_qk_cast1)
...@@ -157,7 +153,7 @@ def tl_indexer_bwd_impl( ...@@ -157,7 +153,7 @@ def tl_indexer_bwd_impl(
for i, j in T.Parallel(block_I, dim): for i, j in T.Parallel(block_I, dim):
pos = indices_shared[i] pos = indices_shared[i]
if ((pos > -1) & (pos <= i_t)): if (pos > -1) & (pos <= i_t):
T.atomic_add(dIndexK[bos + pos, j], d_index_k_frag[i, j]) T.atomic_add(dIndexK[bos + pos, j], d_index_k_frag[i, j])
for i, j in T.Parallel(heads, dim): for i, j in T.Parallel(heads, dim):
...@@ -184,40 +180,35 @@ def indexer_bwd_interface( ...@@ -184,40 +180,35 @@ def indexer_bwd_interface(
dweights = torch.zeros_like(weights) dweights = torch.zeros_like(weights)
dk = torch.zeros_like(k) dk = torch.zeros_like(k)
kernel = tl_indexer_bwd_impl(heads, dim, topk) kernel = tl_indexer_bwd_impl(heads, dim, topk)
kernel(q, weights, k, dq, dweights, dk, attn_score, index_score, topk_indices, offsets, kernel(q, weights, k, dq, dweights, dk, attn_score, index_score, topk_indices, offsets, token_indices)
token_indices)
return dq, dweights, dk return dq, dweights, dk
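The kernel above materializes the weight gradient as (index_score - attn_score) * relu_logits, with sm_scale folded in. A small autograd sanity check of that closed form for a single query, with hypothetical sizes and the causal mask omitted:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
H, D, T, K = 4, 16, 32, 8                 # heads, index dim, keys, top-k
q = torch.randn(H, D)
k = torch.randn(T, D)
w = torch.randn(H, requires_grad=True)
scale = D ** -0.5

relu_logits = F.relu(q @ k.T)                       # [H, T]
score = (relu_logits * w[:, None]).sum(0) * scale   # [T]
topk_val, topk_idx = torch.topk(score, K)
log_p = F.log_softmax(topk_val, dim=-1)
attn = torch.softmax(torch.randn(K), dim=-1)        # stand-in attention distribution
loss = F.kl_div(log_p, attn.log(), log_target=True, reduction="sum")
loss.backward()

p = log_p.exp()
dw_manual = ((p - attn) * relu_logits[:, topk_idx]).sum(-1) * scale
print(torch.allclose(dw_manual, w.grad, atol=1e-5))  # True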
def ref_indexer_bwd(Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, def ref_indexer_bwd(
TopkIndices: torch.Tensor, AttnScore: torch.Tensor, Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, AttnScore: torch.Tensor, offsets: torch.Tensor
offsets: torch.Tensor) -> torch.Tensor: ) -> torch.Tensor:
Q.requires_grad_(True) Q.requires_grad_(True)
Weights.requires_grad_(True) Weights.requires_grad_(True)
K.requires_grad_(True) K.requires_grad_(True)
softmax_scale = Q.shape[-1]**-0.5 softmax_scale = Q.shape[-1] ** -0.5
all_loss = [] all_loss = []
all_log_topk_prob = [] all_log_topk_prob = []
for i in range(offsets.shape[0] - 1): for i in range(offsets.shape[0] - 1):
assert (offsets[i + 1] - offsets[i]).item() >= TopkIndices.shape[-1] assert (offsets[i + 1] - offsets[i]).item() >= TopkIndices.shape[-1]
q = Q[offsets[i]:offsets[i + 1]] q = Q[offsets[i] : offsets[i + 1]]
weights = Weights[offsets[i]:offsets[i + 1]] weights = Weights[offsets[i] : offsets[i + 1]]
k = K[offsets[i]:offsets[i + 1]] k = K[offsets[i] : offsets[i + 1]]
topk_indices = TopkIndices[offsets[i]:offsets[i + 1]] topk_indices = TopkIndices[offsets[i] : offsets[i + 1]]
attn_score = AttnScore[offsets[i]:offsets[i + 1]] attn_score = AttnScore[offsets[i] : offsets[i + 1]]
s = q.shape[0] s = q.shape[0]
mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device) mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device)
logits = einsum(q, k, 's1 h k, s2 k -> s1 h s2') * softmax_scale logits = einsum(q, k, "s1 h k, s2 k -> s1 h s2") * softmax_scale
logits = F.relu(logits) logits = F.relu(logits)
score = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) score = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32)
score = torch.where(mask, score, float('-inf')) score = torch.where(mask, score, float("-inf"))
topk_value = torch.gather(score, dim=-1, index=topk_indices.to(torch.int64)) topk_value = torch.gather(score, dim=-1, index=topk_indices.to(torch.int64))
log_topk_prob = F.log_softmax(topk_value, dim=-1, dtype=torch.float32) log_topk_prob = F.log_softmax(topk_value, dim=-1, dtype=torch.float32)
loss = F.kl_div( loss = F.kl_div(log_topk_prob.clip(-100, 0), attn_score.log().clip(-100, 0), log_target=True, reduction="sum")
log_topk_prob.clip(-100, 0),
attn_score.log().clip(-100, 0),
log_target=True,
reduction="sum")
all_loss.append(loss) all_loss.append(loss)
all_log_topk_prob.append(log_topk_prob) all_log_topk_prob.append(log_topk_prob)
loss = torch.stack(all_loss).sum() loss = torch.stack(all_loss).sum()
...@@ -244,15 +235,13 @@ def test_kernel( ...@@ -244,15 +235,13 @@ def test_kernel(
seq_len = (offsets[i + 1] - offsets[i]).item() seq_len = (offsets[i + 1] - offsets[i]).item()
mask = (torch.arange(seq_len)[:, None] >= torch.arange(topk)[None, :]).to(q.device) mask = (torch.arange(seq_len)[:, None] >= torch.arange(topk)[None, :]).to(q.device)
logits = torch.ones(seq_len, topk).cuda() logits = torch.ones(seq_len, topk).cuda()
logits = torch.where(mask, logits, float('-inf')) logits = torch.where(mask, logits, float("-inf"))
attn_score = F.softmax(logits, dim=-1, dtype=torch.float32) attn_score = F.softmax(logits, dim=-1, dtype=torch.float32)
all_attn_score.append(attn_score) all_attn_score.append(attn_score)
attn_score = torch.cat(all_attn_score, dim=0) attn_score = torch.cat(all_attn_score, dim=0)
topk_indices = repeat( topk_indices = repeat(torch.arange(topk, dtype=torch.int32).cuda(), "k -> s k", s=S).contiguous()
torch.arange(topk, dtype=torch.int32).cuda(), 'k -> s k', s=S).contiguous() index_score, ref_dq, ref_dw, ref_dk = ref_indexer_bwd(q, w, k, topk_indices, attn_score, offsets)
index_score, ref_dq, ref_dw, ref_dk = ref_indexer_bwd(q, w, k, topk_indices, attn_score,
offsets)
dq, dw, dk = indexer_bwd_interface(q, w, k, attn_score, index_score, topk_indices, offsets) dq, dw, dk = indexer_bwd_interface(q, w, k, attn_score, index_score, topk_indices, offsets)
...@@ -261,5 +250,5 @@ def test_kernel( ...@@ -261,5 +250,5 @@ def test_kernel(
print(f"dq err: {get_abs_err(dk, ref_dk):.6f} ratio: {get_err_ratio(dk, ref_dk):.6f}") print(f"dq err: {get_abs_err(dk, ref_dk):.6f} ratio: {get_err_ratio(dk, ref_dk):.6f}")
if __name__ == '__main__': if __name__ == "__main__":
test_kernel() test_kernel()
...@@ -53,8 +53,8 @@ def tl_indexer_topk_reducesum_impl( ...@@ -53,8 +53,8 @@ def tl_indexer_topk_reducesum_impl(
@T.macro @T.macro
def bitonic_sort( def bitonic_sort(
topk_index_shared: T.SharedBuffer([N], dtype=INT32), topk_index_shared: T.SharedBuffer([N], dtype=INT32),
topk_value_shared: T.SharedBuffer([N], dtype=FP32), topk_value_shared: T.SharedBuffer([N], dtype=FP32),
): ):
T.sync_threads() T.sync_threads()
for i1 in T.serial(num_iters): for i1 in T.serial(num_iters):
...@@ -62,9 +62,10 @@ def tl_indexer_topk_reducesum_impl( ...@@ -62,9 +62,10 @@ def tl_indexer_topk_reducesum_impl(
for i in T.Parallel(N): for i in T.Parallel(N):
ascending = (i & (1 << (i1 + 1))) != 0 ascending = (i & (1 << (i1 + 1))) != 0
j = i ^ (1 << (i1 - i2)) j = i ^ (1 << (i1 - i2))
if i < j and \ if i < j and (
((ascending and topk_value_shared[i] > topk_value_shared[j]) or ( (ascending and topk_value_shared[i] > topk_value_shared[j])
not ascending and topk_value_shared[i] < topk_value_shared[j])): or (not ascending and topk_value_shared[i] < topk_value_shared[j])
):
val = topk_value_shared[i] val = topk_value_shared[i]
topk_value_shared[i] = topk_value_shared[j] topk_value_shared[i] = topk_value_shared[j]
topk_value_shared[j] = val topk_value_shared[j] = val
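For reference, the compare-exchange network that the bitonic_sort macro implements, written as plain Python over a single power-of-two-length list; the macro additionally swaps topk_index_shared alongside the values and inverts the direction flag so the largest scores end up at the front:

def bitonic_sort(vals):
    # In-place ascending bitonic sort; len(vals) must be a power of two.
    n = len(vals)
    k = 2
    while k <= n:                       # build bitonic runs of length k
        j = k // 2
        while j >= 1:                   # compare-exchange with stride j
            for i in range(n):
                partner = i ^ j
                if partner > i:
                    ascending = (i & k) == 0
                    if (vals[i] > vals[partner]) == ascending:
                        vals[i], vals[partner] = vals[partner], vals[i]
            j //= 2
        k *= 2
    return vals

print(bitonic_sort([3.0, 1.0, 4.0, 1.5, 9.0, 2.0, 6.0, 5.0]))  # sorted ascending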
...@@ -75,13 +76,13 @@ def tl_indexer_topk_reducesum_impl( ...@@ -75,13 +76,13 @@ def tl_indexer_topk_reducesum_impl(
@T.prim_func @T.prim_func
def tl_indexer_topk_reducesum_kernel( def tl_indexer_topk_reducesum_kernel(
IndexQ: T.Tensor(index_q_shape, dtype), IndexQ: T.Tensor(index_q_shape, dtype),
Weights: T.Tensor(weights_shape, dtype), Weights: T.Tensor(weights_shape, dtype),
IndexK: T.Tensor(index_k_shape, dtype), IndexK: T.Tensor(index_k_shape, dtype),
TopkIndices: T.Tensor(topk_indices_shape, INT32), TopkIndices: T.Tensor(topk_indices_shape, INT32),
ReduceSum: T.Tensor(topk_indices_shape, FP32), ReduceSum: T.Tensor(topk_indices_shape, FP32),
Offsets: T.Tensor(offsets_shape, INT32), Offsets: T.Tensor(offsets_shape, INT32),
TokenIndices: T.Tensor(token_indices_shape, INT32), TokenIndices: T.Tensor(token_indices_shape, INT32),
): ):
with T.Kernel(seq_len, threads=num_threads) as (bx): with T.Kernel(seq_len, threads=num_threads) as (bx):
i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1] i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1]
...@@ -92,7 +93,7 @@ def tl_indexer_topk_reducesum_impl( ...@@ -92,7 +93,7 @@ def tl_indexer_topk_reducesum_impl(
topk_value_shared = T.alloc_shared([N], dtype=FP32) topk_value_shared = T.alloc_shared([N], dtype=FP32)
T.fill(topk_index_shared, -1) T.fill(topk_index_shared, -1)
T.fill(topk_value_shared, float('-inf')) T.fill(topk_value_shared, float("-inf"))
T.sync_threads() T.sync_threads()
index_q_shared = T.alloc_shared([heads, dim], dtype=dtype) index_q_shared = T.alloc_shared([heads, dim], dtype=dtype)
...@@ -113,8 +114,7 @@ def tl_indexer_topk_reducesum_impl( ...@@ -113,8 +114,7 @@ def tl_indexer_topk_reducesum_impl(
index_k_shared = T.alloc_shared([block_K, dim], dtype=dtype) index_k_shared = T.alloc_shared([block_K, dim], dtype=dtype)
for i, j in T.Parallel(block_K, dim): for i, j in T.Parallel(block_K, dim):
index_k_shared[i, j] = T.if_then_else(k_st + i < k_ed, IndexK[bos + k_st + i, index_k_shared[i, j] = T.if_then_else(k_st + i < k_ed, IndexK[bos + k_st + i, j], 0)
j], 0)
T.sync_threads() T.sync_threads()
logits = T.alloc_fragment((block_K, heads), FP32) logits = T.alloc_fragment((block_K, heads), FP32)
...@@ -144,7 +144,7 @@ def tl_indexer_topk_reducesum_impl( ...@@ -144,7 +144,7 @@ def tl_indexer_topk_reducesum_impl(
T.sync_threads() T.sync_threads()
for i in T.Parallel(block_K): for i in T.Parallel(block_K):
if k_st + i > i_t: if k_st + i > i_t:
logits_sum[i] = float('-inf') logits_sum[i] = float("-inf")
j = offset + i j = offset + i
topk_index_shared[j] = k_st + i topk_index_shared[j] = k_st + i
topk_value_shared[j] = logits_sum[i] topk_value_shared[j] = logits_sum[i]
...@@ -209,22 +209,21 @@ def indexer_topk_reducesum_interface( ...@@ -209,22 +209,21 @@ def indexer_topk_reducesum_interface(
return topk_indices, topk_score return topk_indices, topk_score
def ref_index_score(Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, topk: int, def ref_index_score(Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, topk: int, offsets: torch.Tensor) -> torch.Tensor:
offsets: torch.Tensor) -> torch.Tensor:
all_topk_indices = [] all_topk_indices = []
all_topk_score = [] all_topk_score = []
for i in range(offsets.shape[0] - 1): for i in range(offsets.shape[0] - 1):
assert (offsets[i + 1] - offsets[i]).item() >= topk assert (offsets[i + 1] - offsets[i]).item() >= topk
q = Q[offsets[i]:offsets[i + 1]] q = Q[offsets[i] : offsets[i + 1]]
weights = Weights[offsets[i]:offsets[i + 1]] weights = Weights[offsets[i] : offsets[i + 1]]
k = K[offsets[i]:offsets[i + 1]] k = K[offsets[i] : offsets[i + 1]]
softmax_scale = q.shape[-1]**-0.5 softmax_scale = q.shape[-1] ** -0.5
s = q.shape[0] s = q.shape[0]
mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device) mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device)
logits = einsum(q, k, 's1 h k, s2 k -> s1 h s2') logits = einsum(q, k, "s1 h k, s2 k -> s1 h s2")
logits = F.relu(logits) logits = F.relu(logits)
logits = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) * softmax_scale logits = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) * softmax_scale
logits = torch.where(mask, logits, float('-inf')) logits = torch.where(mask, logits, float("-inf"))
topk_logits, topk_indices = torch.topk(logits, k=topk, dim=-1) topk_logits, topk_indices = torch.topk(logits, k=topk, dim=-1)
topk_score = F.softmax(topk_logits, dim=-1, dtype=torch.float32) topk_score = F.softmax(topk_logits, dim=-1, dtype=torch.float32)
all_topk_indices.append(topk_indices) all_topk_indices.append(topk_indices)
...@@ -265,13 +264,10 @@ def test_kernel( ...@@ -265,13 +264,10 @@ def test_kernel(
set_trt = set(trt_np[mask]) set_trt = set(trt_np[mask])
intersection = set_ref & set_trt intersection = set_ref & set_trt
print("idx:", j, "selected/all:", len(intersection), "/", len(set_ref), "=", print("idx:", j, "selected/all:", len(intersection), "/", len(set_ref), "=", len(intersection) / len(set_ref))
len(intersection) / len(set_ref))
print( print(f"err: {get_abs_err(ref_np_val, trt_np_val):.6f} ratio: {get_err_ratio(ref_np_val, trt_np_val):.6f}")
f"err: {get_abs_err(ref_np_val, trt_np_val):.6f} ratio: {get_err_ratio(ref_np_val, trt_np_val):.6f}"
)
if __name__ == '__main__': if __name__ == "__main__":
test_kernel() test_kernel()
...@@ -19,15 +19,15 @@ def preprocess( ...@@ -19,15 +19,15 @@ def preprocess(
assert dtype == "bfloat16" assert dtype == "bfloat16"
assert accum_dtype == "float" assert accum_dtype == "float"
S = T.symbolic('S') S = T.symbolic("S")
shape = [S, H, D] shape = [S, H, D]
@T.prim_func @T.prim_func
def preprocess_kernel( def preprocess_kernel(
O: T.Tensor(shape, dtype), O: T.Tensor(shape, dtype),
dO: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype),
Delta: T.Tensor([S, H], accum_dtype), Delta: T.Tensor([S, H], accum_dtype),
): ):
with T.Kernel(H, T.ceildiv(S, block_ND)) as (bx, by): with T.Kernel(H, T.ceildiv(S, block_ND)) as (bx, by):
o = T.alloc_fragment([block_ND, block_ND], accum_dtype) o = T.alloc_fragment([block_ND, block_ND], accum_dtype)
...@@ -36,13 +36,12 @@ def preprocess( ...@@ -36,13 +36,12 @@ def preprocess(
acc = T.alloc_fragment([block_ND, block_ND], accum_dtype) acc = T.alloc_fragment([block_ND, block_ND], accum_dtype)
T.clear(acc) T.clear(acc)
for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages): for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages):
T.copy(O[by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], o) T.copy(O[by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], o)
T.copy(dO[by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], T.copy(dO[by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], do)
do)
for i, j in T.Parallel(block_ND, block_ND): for i, j in T.Parallel(block_ND, block_ND):
acc[i, j] += o[i, j] * do[i, j] acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1) T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[by * block_ND:(by + 1) * block_ND, bx]) T.copy(delta, Delta[by * block_ND : (by + 1) * block_ND, bx])
return preprocess_kernel return preprocess_kernel
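The preprocess kernel computes Delta[s, h] = sum_d O[s, h, d] * dO[s, h, d], tiling the reduction in block_ND chunks with float32 accumulation. An equivalent plain-PyTorch reference with hypothetical sizes:

import torch

S, H, D = 128, 4, 64
o = torch.randn(S, H, D, dtype=torch.bfloat16)
do = torch.randn(S, H, D, dtype=torch.bfloat16)
delta_ref = (o.float() * do.float()).sum(dim=-1)  # [S, H], accumulated in float32
print(delta_ref.shape)  # torch.Size([128, 4])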
...@@ -59,19 +58,19 @@ def postprocess( ...@@ -59,19 +58,19 @@ def postprocess(
): ):
assert dtype == "bfloat16" assert dtype == "bfloat16"
assert accum_dtype == "float" assert accum_dtype == "float"
S_kv = T.symbolic('S_kv') S_kv = T.symbolic("S_kv")
dkv_shape = [S_kv, kv_group, D + D_tail] dkv_shape = [S_kv, kv_group, D + D_tail]
@T.prim_func @T.prim_func
def postprocess_kernel( def postprocess_kernel(
dKV: T.Tensor(dkv_shape, accum_dtype), dKV: T.Tensor(dkv_shape, accum_dtype),
dKV_out: T.Tensor(dkv_shape, dtype), dKV_out: T.Tensor(dkv_shape, dtype),
): ):
with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, threads=threads) as (bx, by): with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, threads=threads) as (bx, by):
T.copy( T.copy(
dKV[bx * block_N:(bx + 1) * block_N, by, :], dKV[bx * block_N : (bx + 1) * block_N, by, :],
dKV_out[bx * block_N:(bx + 1) * block_N, by, :], dKV_out[bx * block_N : (bx + 1) * block_N, by, :],
) )
return postprocess_kernel return postprocess_kernel
...@@ -82,7 +81,8 @@ def postprocess( ...@@ -82,7 +81,8 @@ def postprocess(
pass_configs={ pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}) },
)
def bwd( def bwd(
H, H,
D, D,
...@@ -98,17 +98,17 @@ def bwd( ...@@ -98,17 +98,17 @@ def bwd(
dtype="bfloat16", dtype="bfloat16",
accum_dtype="float", accum_dtype="float",
): ):
assert is_causal == True, 'non-causal is not supported now' assert is_causal == True, "non-causal is not supported now"
assert topk % block_size == 0, 'otherwise will load some index=0 thus causing wrong kv to be loaded' assert topk % block_size == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded"
assert dtype == "bfloat16" assert dtype == "bfloat16"
assert accum_dtype == "float" assert accum_dtype == "float"
assert indices_dtype == "int32" assert indices_dtype == "int32"
if sm_scale is None: if sm_scale is None:
sm_scale = (D + D_tail)**(-0.5) sm_scale = (D + D_tail) ** (-0.5)
B_plus_one = T.symbolic('B_plus_one') B_plus_one = T.symbolic("B_plus_one")
S = T.symbolic('S') S = T.symbolic("S")
H_kv = H // kv_group H_kv = H // kv_group
q_shape = [S, H, D + D_tail] q_shape = [S, H, D + D_tail]
...@@ -132,16 +132,16 @@ def bwd( ...@@ -132,16 +132,16 @@ def bwd(
@T.prim_func @T.prim_func
def sparse_mla_bwd_kernel( def sparse_mla_bwd_kernel(
Q: T.Tensor(q_shape, dtype), Q: T.Tensor(q_shape, dtype),
KV: T.Tensor(k_shape, dtype), KV: T.Tensor(k_shape, dtype),
dO: T.Tensor(o_shape, dtype), dO: T.Tensor(o_shape, dtype),
Indices: T.Tensor(indices_shape, indices_dtype), Indices: T.Tensor(indices_shape, indices_dtype),
Lse: T.Tensor(lse_shape, accum_dtype), Lse: T.Tensor(lse_shape, accum_dtype),
Delta: T.Tensor(delta_shape, accum_dtype), Delta: T.Tensor(delta_shape, accum_dtype),
Offsets: T.Tensor(offsets_shape, indices_dtype), Offsets: T.Tensor(offsets_shape, indices_dtype),
TokenIndices: T.Tensor(token_indices_shape, indices_dtype), TokenIndices: T.Tensor(token_indices_shape, indices_dtype),
dQ: T.Tensor(q_shape, dtype), dQ: T.Tensor(q_shape, dtype),
dKV: T.Tensor(k_shape, accum_dtype), dKV: T.Tensor(k_shape, accum_dtype),
): ):
with T.Kernel(S, kv_group, threads=threads) as (b_s_i, bz): with T.Kernel(S, kv_group, threads=threads) as (b_s_i, bz):
Q_shared = T.alloc_shared([padded_H, D], dtype) Q_shared = T.alloc_shared([padded_H, D], dtype)
...@@ -163,32 +163,32 @@ def bwd( ...@@ -163,32 +163,32 @@ def bwd(
acc_dkv = T.alloc_fragment([BS, D], accum_dtype) acc_dkv = T.alloc_fragment([BS, D], accum_dtype)
acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype) acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype)
acc_dkv_shared = T.view(KV_shared, shape=[BS // split_store, D], dtype=accum_dtype) acc_dkv_shared = T.view(KV_shared, shape=[BS // split_store, D], dtype=accum_dtype)
acc_dkv_tail_shared = T.view( acc_dkv_tail_shared = T.view(KV_tail_shared, shape=[BS // split_store, D_tail], dtype=accum_dtype)
KV_tail_shared, shape=[BS // split_store, D_tail], dtype=accum_dtype)
b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1] b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1]
bos, eos = Offsets[b_i], Offsets[b_i + 1] bos, eos = Offsets[b_i], Offsets[b_i + 1]
max_kv_i = s_i max_kv_i = s_i
T.copy(Q[bos + s_i, bz * padded_H:(bz + 1) * padded_H, :D], Q_shared) T.copy(Q[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D], Q_shared)
T.copy(Q[bos + s_i, bz * padded_H:(bz + 1) * padded_H, D:], Q_tail_shared) T.copy(Q[bos + s_i, bz * padded_H : (bz + 1) * padded_H, D:], Q_tail_shared)
T.copy(dO[bos + s_i, bz * padded_H:(bz + 1) * padded_H, :D], dO_shared) T.copy(dO[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D], dO_shared)
T.clear(acc_dq) T.clear(acc_dq)
T.clear(acc_dq_tail) T.clear(acc_dq_tail)
T.annotate_layout({ T.annotate_layout(
dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared), {
dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared), dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared),
}) dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared),
}
)
# Process each block of indices # Process each block of indices
for i_i in T.Pipelined(NS, num_stages=num_stages): for i_i in T.Pipelined(NS, num_stages=num_stages):
# Check which indices are valid # Check which indices are valid
for bi_i in T.Parallel(BS): for bi_i in T.Parallel(BS):
mask[bi_i] = (Indices[bos + s_i, bz, i_i * BS + bi_i] <= max_kv_i) & ( mask[bi_i] = (Indices[bos + s_i, bz, i_i * BS + bi_i] <= max_kv_i) & (Indices[bos + s_i, bz, i_i * BS + bi_i] != -1)
Indices[bos + s_i, bz, i_i * BS + bi_i] != -1)
# Compute attention scores # Compute attention scores
for h_i, bi_i in T.Parallel(padded_H, BS): for h_i, bi_i in T.Parallel(padded_H, BS):
...@@ -196,65 +196,33 @@ def bwd( ...@@ -196,65 +196,33 @@ def bwd(
# Load KV, V for this block of indices # Load KV, V for this block of indices
for bi_i, d_i in T.Parallel(BS, D): for bi_i, d_i in T.Parallel(BS, D):
KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, d_i]
d_i]
T.gemm( T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
for bi_i, d_i in T.Parallel(BS, D_tail): for bi_i, d_i in T.Parallel(BS, D_tail):
KV_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], KV_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, D + d_i]
bz, D + d_i] T.gemm(Q_tail_shared, KV_tail_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(
Q_tail_shared,
KV_tail_shared,
acc_p,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
for h_i, bi_i in T.Parallel(padded_H, BS): for h_i, bi_i in T.Parallel(padded_H, BS):
acc_p[h_i, bi_i] = T.exp(acc_p[h_i, bi_i] * sm_scale - acc_p[h_i, bi_i] = T.exp(acc_p[h_i, bi_i] * sm_scale - Lse[bos + s_i, bz * padded_H + h_i])
Lse[bos + s_i, bz * padded_H + h_i])
T.copy(acc_p, P_shared_cast) T.copy(acc_p, P_shared_cast)
T.gemm( T.gemm(dO_shared, KV_shared, acc_dp, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
dO_shared,
KV_shared,
acc_dp,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol,
clear_accum=True)
for h_i, bi_i in T.Parallel(padded_H, BS): for h_i, bi_i in T.Parallel(padded_H, BS):
acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * ( acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[bos + s_i, bz * padded_H + h_i]) * sm_scale
acc_dp[h_i, bi_i] - Delta[bos + s_i, bz * padded_H + h_i]) * sm_scale
T.copy(acc_dp, dP_shared_cast) T.copy(acc_dp, dP_shared_cast)
T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol) T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol)
T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol) T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol)
T.gemm( T.gemm(dP_shared_cast, Q_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
dP_shared_cast, T.gemm(P_shared_cast, dO_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol)
Q_shared,
acc_dkv,
transpose_A=True,
policy=T.GemmWarpPolicy.FullCol,
clear_accum=True)
T.gemm(
P_shared_cast,
dO_shared,
acc_dkv,
transpose_A=True,
policy=T.GemmWarpPolicy.FullCol)
T.clear(acc_dkv_tail) T.clear(acc_dkv_tail)
T.gemm( T.gemm(dP_shared_cast, Q_tail_shared, acc_dkv_tail, transpose_A=True, policy=T.GemmWarpPolicy.FullCol)
dP_shared_cast,
Q_tail_shared,
acc_dkv_tail,
transpose_A=True,
policy=T.GemmWarpPolicy.FullCol)
for s in range(split_store): for s in range(split_store):
for bi_i, d_i in T.Parallel(BS, D): for bi_i, d_i in T.Parallel(BS, D):
...@@ -263,44 +231,32 @@ def bwd( ...@@ -263,44 +231,32 @@ def bwd(
for bi_i, d_i in T.Parallel(BS, D_tail): for bi_i, d_i in T.Parallel(BS, D_tail):
if bi_i < BS // split_store: if bi_i < BS // split_store:
acc_dkv_tail_shared[bi_i, acc_dkv_tail_shared[bi_i, d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), d_i]
d_i] = acc_dkv_tail[bi_i + s * (BS // split_store),
d_i]
for bi_i, d_i in T.Parallel(BS // split_store, D // 4): for bi_i, d_i in T.Parallel(BS // split_store, D // 4):
T.atomic_addx4( T.atomic_addx4(
dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, d_i * 4],
(BS // split_store)], bz, d_i * 4], acc_dkv_shared[bi_i, d_i * 4],
acc_dkv_shared[bi_i, d_i * 4]) )
# Atomically update dKV, dKV_tail tensors # Atomically update dKV, dKV_tail tensors
for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4): for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4):
T.atomic_addx4( T.atomic_addx4(
dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, D + d_i * 4],
(BS // split_store)], bz, D + d_i * 4], acc_dkv_tail_shared[bi_i, d_i * 4],
acc_dkv_tail_shared[bi_i, d_i * 4]) )
# Store the accumulated dQ # Store the accumulated dQ
T.copy(acc_dq, dQ_shared) T.copy(acc_dq, dQ_shared)
T.copy(acc_dq_tail, dQ_tail_shared) T.copy(acc_dq_tail, dQ_tail_shared)
T.copy(dQ_shared, dQ[bos + s_i, bz * padded_H:(bz + 1) * padded_H, :D]) T.copy(dQ_shared, dQ[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D])
T.copy(dQ_tail_shared, dQ[bos + s_i, bz * padded_H:(bz + 1) * padded_H, D:]) T.copy(dQ_tail_shared, dQ[bos + s_i, bz * padded_H : (bz + 1) * padded_H, D:])
return sparse_mla_bwd_kernel return sparse_mla_bwd_kernel
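A dense single-query reference for the recurrences used in sparse_mla_bwd_kernel, with hypothetical sizes; here values and keys share the same D columns, whereas the kernel additionally carries a D_tail part for the keys and scatters dKV with atomics over the top-k indices:

import torch

torch.manual_seed(0)
H, D, K = 8, 64, 16                     # query heads, head dim, selected top-k entries
scale = D ** -0.5
q = torch.randn(H, D)
kv = torch.randn(K, D)                  # gathered top-k keys/values
do = torch.randn(H, D)

s = (q @ kv.T) * scale                  # attention scores over the selected entries
lse = torch.logsumexp(s, dim=-1, keepdim=True)
p = torch.exp(s - lse)                  # softmax probabilities, reconstructed from Lse
o = p @ kv
delta = (o * do).sum(-1, keepdim=True)  # the quantity produced by preprocess()
dp = p * ((do @ kv.T) - delta) * scale  # matches acc_dp in the kernel
dq = dp @ kv                            # matches acc_dq
dkv = dp.T @ q + p.T @ do               # matches acc_dkv (dK part + dV part)
print(dq.shape, dkv.shape)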
def sparse_mla_bwd(q, def sparse_mla_bwd(q, kv, o, do, indices, lse, offsets, sm_scale=None, is_casual=True, return_kernel=False, delta=None):
kv,
o,
do,
indices,
lse,
offsets,
sm_scale=None,
is_casual=True,
return_kernel=False,
delta=None):
assert q.is_contiguous() assert q.is_contiguous()
assert kv.is_contiguous() assert kv.is_contiguous()
assert indices.is_contiguous() assert indices.is_contiguous()
...@@ -333,16 +289,9 @@ def sparse_mla_bwd(q, ...@@ -333,16 +289,9 @@ def sparse_mla_bwd(q,
return dq, dkv return dq, dkv
def ref_sparse_mla_bwd_interface(q, def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, offsets, sm_scale=None, is_casual=True):
kv,
o,
do,
indices,
lse,
offsets,
sm_scale=None,
is_casual=True):
from sparse_mla_fwd import ref_sparse_mla_fwd_interface from sparse_mla_fwd import ref_sparse_mla_fwd_interface
q = q.detach().clone() q = q.detach().clone()
kv = kv.detach().clone() kv = kv.detach().clone()
q.requires_grad = True q.requires_grad = True
...@@ -352,32 +301,25 @@ def ref_sparse_mla_bwd_interface(q, ...@@ -352,32 +301,25 @@ def ref_sparse_mla_bwd_interface(q,
return q.grad, kv.grad return q.grad, kv.grad
def test_sparse_mla_bwd(B=1, def test_sparse_mla_bwd(B=1, S=2048, H=64, HKV=1, DQKV=576, DV=512, topk=512, dtype=torch.bfloat16, check_correctness=True):
S=2048,
H=64,
HKV=1,
DQKV=576,
DV=512,
topk=512,
dtype=torch.bfloat16,
check_correctness=True):
# Prepare data # Prepare data
q = torch.randn((S, H, DQKV), dtype=dtype, device='cuda').requires_grad_(True) q = torch.randn((S, H, DQKV), dtype=dtype, device="cuda").requires_grad_(True)
kv = torch.randn((S, HKV, DQKV), dtype=dtype, device='cuda').requires_grad_(True) kv = torch.randn((S, HKV, DQKV), dtype=dtype, device="cuda").requires_grad_(True)
do = torch.randn((S, H, DV), dtype=dtype, device='cuda') do = torch.randn((S, H, DV), dtype=dtype, device="cuda")
offsets = torch.tensor([0, S], dtype=torch.int32, device="cuda") offsets = torch.tensor([0, S], dtype=torch.int32, device="cuda")
indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device='cuda') indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device="cuda")
for i in range(offsets.shape[0] - 1): for i in range(offsets.shape[0] - 1):
seq_len = (offsets[i + 1] - offsets[i]).item() seq_len = (offsets[i + 1] - offsets[i]).item()
assert seq_len >= topk assert seq_len >= topk
for t in range(seq_len): for t in range(seq_len):
for h in range(HKV): for h in range(HKV):
i_i = torch.randperm(max(1, t))[:topk] i_i = torch.randperm(max(1, t))[:topk]
indices[offsets[i] + t, h, :len(i_i)] = i_i indices[offsets[i] + t, h, : len(i_i)] = i_i
# Forward # Forward
from sparse_mla_fwd import sparse_mla_fwd_interface from sparse_mla_fwd import sparse_mla_fwd_interface
tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, offsets) tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, offsets)
tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse, offsets) tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse, offsets)
...@@ -388,13 +330,15 @@ def test_sparse_mla_bwd(B=1, ...@@ -388,13 +330,15 @@ def test_sparse_mla_bwd(B=1,
assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv") assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv")
print("assert_tensors_similar passed") print("assert_tensors_similar passed")
per_token_flop = 2 * sum([ per_token_flop = 2 * sum(
H * DV * topk, [
H * DQKV * topk, H * DV * topk,
H * DQKV * topk, H * DQKV * topk,
H * DQKV * topk, H * DQKV * topk,
H * DV * topk, H * DQKV * topk,
]) H * DV * topk,
]
)
from tilelang.profiler import do_bench from tilelang.profiler import do_bench
def fn(): def fn():
...@@ -402,19 +346,9 @@ def test_sparse_mla_bwd(B=1, ...@@ -402,19 +346,9 @@ def test_sparse_mla_bwd(B=1,
ms = do_bench(fn, rep=100, warmup=250) ms = do_bench(fn, rep=100, warmup=250)
print(f"Average time: {ms:.3f} ms") print(f"Average time: {ms:.3f} ms")
print(f'bwd io bandwidth = ', print(f"bwd io bandwidth = ", (B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12)
(B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12) print(f"bwd tflops = ", per_token_flop * S / (ms * 1e-3) / 1e12)
print(f'bwd tflops = ', per_token_flop * S / (ms * 1e-3) / 1e12)
if __name__ == "__main__": if __name__ == "__main__":
test_sparse_mla_bwd( test_sparse_mla_bwd(B=1, S=2048, H=64, HKV=1, DQKV=576, DV=512, topk=512, dtype=torch.bfloat16, check_correctness=True)
B=1,
S=2048,
H=64,
HKV=1,
DQKV=576,
DV=512,
topk=512,
dtype=torch.bfloat16,
check_correctness=True)