"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "003e2eb9fa7a431b1600de0b72462c55334f71f6"
Unverified commit 29051439, authored by Lei Wang and committed by GitHub

[Lint] Phase out Yapf format and embrace Ruff format (#1417)

parent e84b24bc
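Most of the hunks below are mechanical restyling rather than behavioral changes: ruff format puts spaces around ":" in slices with compound bounds, prefers double-quoted strings, collapses short argument lists back onto one line, and re-wraps long calls with trailing commas. A minimal before/after sketch, taken from the patterns visible in this diff (illustrative only):

# yapf-era formatting (before)
T.copy(delta, Delta[bz, by * block_ND:(by + 1) * block_ND, bx])
assert topk % block_size == 0, 'otherwise will load some index=0 thus causing wrong kv to be loaded'

# ruff format (after)
T.copy(delta, Delta[bz, by * block_ND : (by + 1) * block_ND, bx])
assert topk % block_size == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded"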
......@@ -22,9 +22,9 @@ def preprocess(
@T.prim_func
def preprocess_kernel(
O: T.Tensor(shape, dtype),
dO: T.Tensor(shape, dtype),
Delta: T.Tensor([B, S, H], accum_dtype),
O: T.Tensor(shape, dtype),
dO: T.Tensor(shape, dtype),
Delta: T.Tensor([B, S, H], accum_dtype),
):
with T.Kernel(H, T.ceildiv(S, block_ND), B) as (bx, by, bz):
o = T.alloc_fragment([block_ND, block_ND], accum_dtype)
......@@ -33,16 +33,12 @@ def preprocess(
acc = T.alloc_fragment([block_ND, block_ND], accum_dtype)
T.clear(acc)
for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages):
T.copy(
O[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND],
o)
T.copy(
dO[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND],
do)
T.copy(O[bz, by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], o)
T.copy(dO[bz, by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], do)
for i, j in T.Parallel(block_ND, block_ND):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, by * block_ND:(by + 1) * block_ND, bx])
T.copy(delta, Delta[bz, by * block_ND : (by + 1) * block_ND, bx])
return preprocess_kernel
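For reference, the quantity preprocess_kernel accumulates is the standard attention-backward delta, Delta[b, s, h] = sum_d O[b, s, h, d] * dO[b, s, h, d]. A minimal PyTorch sketch, assuming O and dO are [B, S, H, D] tensors (names hypothetical, not part of this file):

import torch

def ref_preprocess(O: torch.Tensor, dO: torch.Tensor) -> torch.Tensor:
    # row-wise dot product of O and dO over the feature dimension, accumulated in float32
    return (O.float() * dO.float()).sum(dim=-1)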
......@@ -65,13 +61,13 @@ def postprocess(
@T.prim_func
def postprocess_kernel(
dKV: T.Tensor(dkv_shape, accum_dtype),
dKV_out: T.Tensor(dkv_shape, dtype),
dKV: T.Tensor(dkv_shape, accum_dtype),
dKV_out: T.Tensor(dkv_shape, dtype),
):
with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, B, threads=threads) as (bx, by, bz):
T.copy(
dKV[bz, bx * block_N:(bx + 1) * block_N, by, :],
dKV_out[bz, bx * block_N:(bx + 1) * block_N, by, :],
dKV[bz, bx * block_N : (bx + 1) * block_N, by, :],
dKV_out[bz, bx * block_N : (bx + 1) * block_N, by, :],
)
return postprocess_kernel
......@@ -83,7 +79,8 @@ def postprocess(
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True,
})
},
)
def bwd(
B,
S,
......@@ -102,14 +99,14 @@ def bwd(
dtype="bfloat16",
accum_dtype="float",
):
assert is_causal == True, 'non-casual is not supported now'
assert topk % block_size == 0, 'otherwise will load some index=0 thus causing wrong kv to be loaded'
assert is_causal == True, "non-casual is not supported now"
assert topk % block_size == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded"
assert dtype == "bfloat16"
assert accum_dtype == "float"
assert indices_dtype == "int32"
if sm_scale is None:
sm_scale = (D + D_tail)**(-0.5)
sm_scale = (D + D_tail) ** (-0.5)
sm_scale_mul_reciprocal_log2 = sm_scale * 1.44269504 # log2(e)
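# Folding log2(e) into the scale uses the identity exp(x * sm_scale) == 2 ** (x * sm_scale * log2(e)),
# so the softmax below can be evaluated with the cheaper exp2 (see the T.exp2 call) instead of exp.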
H_kv = H // kv_group
......@@ -132,14 +129,14 @@ def bwd(
@T.prim_func
def sparse_mla_bwd_kernel(
Q: T.Tensor(q_shape, dtype),
KV: T.Tensor(k_shape, dtype),
dO: T.Tensor(o_shape, dtype),
Indices: T.Tensor(indices_shape, indices_dtype),
Lse: T.Tensor(lse_shape, accum_dtype),
Delta: T.Tensor(delta_shape, accum_dtype),
dQ: T.Tensor(q_shape, dtype),
dKV: T.Tensor(k_shape, accum_dtype),
Q: T.Tensor(q_shape, dtype),
KV: T.Tensor(k_shape, dtype),
dO: T.Tensor(o_shape, dtype),
Indices: T.Tensor(indices_shape, indices_dtype),
Lse: T.Tensor(lse_shape, accum_dtype),
Delta: T.Tensor(delta_shape, accum_dtype),
dQ: T.Tensor(q_shape, dtype),
dKV: T.Tensor(k_shape, accum_dtype),
):
with T.Kernel(S, B, kv_group, threads=threads) as (s_i, by, bz):
Q_shared = T.alloc_shared([padded_H, D], dtype)
......@@ -165,17 +162,19 @@ def bwd(
max_kv_i = s_i
T.copy(Q[by, s_i, bz * padded_H:(bz + 1) * padded_H, :D], Q_shared)
T.copy(Q[by, s_i, bz * padded_H:(bz + 1) * padded_H, D:], Q_tail_shared)
T.copy(dO[by, s_i, bz * padded_H:(bz + 1) * padded_H, :D], dO_shared)
T.copy(Q[by, s_i, bz * padded_H : (bz + 1) * padded_H, :D], Q_shared)
T.copy(Q[by, s_i, bz * padded_H : (bz + 1) * padded_H, D:], Q_tail_shared)
T.copy(dO[by, s_i, bz * padded_H : (bz + 1) * padded_H, :D], dO_shared)
T.clear(acc_dq)
T.clear(acc_dq_tail)
T.annotate_layout({
dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared),
dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared),
})
T.annotate_layout(
{
dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared),
dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared),
}
)
# Process each block of indices
for i_i in T.Pipelined(NS, num_stages=num_stages):
......@@ -191,62 +190,31 @@ def bwd(
for bi_i, d_i in T.Parallel(BS, D):
KV_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BS + bi_i], bz, d_i]
T.gemm(
Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
for bi_i, d_i in T.Parallel(BS, D_tail):
KV_tail_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BS + bi_i], bz,
D + d_i]
T.gemm(
Q_tail_shared,
KV_tail_shared,
acc_p,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
KV_tail_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BS + bi_i], bz, D + d_i]
T.gemm(Q_tail_shared, KV_tail_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
for h_i, bi_i in T.Parallel(padded_H, BS):
acc_p[h_i, bi_i] = T.exp2(acc_p[h_i, bi_i] * sm_scale_mul_reciprocal_log2 -
Lse[by, s_i, bz * padded_H + h_i])
acc_p[h_i, bi_i] = T.exp2(acc_p[h_i, bi_i] * sm_scale_mul_reciprocal_log2 - Lse[by, s_i, bz * padded_H + h_i])
T.copy(acc_p, P_shared_cast)
T.gemm(
dO_shared,
KV_shared,
acc_dp,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol,
clear_accum=True)
T.gemm(dO_shared, KV_shared, acc_dp, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
for h_i, bi_i in T.Parallel(padded_H, BS):
acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (
acc_dp[h_i, bi_i] - Delta[by, s_i, bz * padded_H + h_i]) * sm_scale
acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[by, s_i, bz * padded_H + h_i]) * sm_scale
T.copy(acc_dp, dP_shared_cast)
T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol)
T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol)
T.gemm(
dP_shared_cast,
Q_shared,
acc_dkv,
transpose_A=True,
policy=T.GemmWarpPolicy.FullCol,
clear_accum=True)
T.gemm(
P_shared_cast,
dO_shared,
acc_dkv,
transpose_A=True,
policy=T.GemmWarpPolicy.FullCol)
T.gemm(dP_shared_cast, Q_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
T.gemm(P_shared_cast, dO_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol)
T.clear(acc_dkv_tail)
T.gemm(
dP_shared_cast,
Q_tail_shared,
acc_dkv_tail,
transpose_A=True,
policy=T.GemmWarpPolicy.FullCol)
T.gemm(dP_shared_cast, Q_tail_shared, acc_dkv_tail, transpose_A=True, policy=T.GemmWarpPolicy.FullCol)
for s in range(split_store):
for bi_i, d_i in T.Parallel(BS, D):
......@@ -255,41 +223,32 @@ def bwd(
for bi_i, d_i in T.Parallel(BS, D_tail):
if bi_i < BS // split_store:
acc_dkv_tail_shared[bi_i,
d_i] = acc_dkv_tail[bi_i + s * (BS // split_store),
d_i]
acc_dkv_tail_shared[bi_i, d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), d_i]
for bi_i, d_i in T.Parallel(BS // split_store, D // 4):
T.atomic_addx4(
dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)],
bz, d_i * 4], acc_dkv_shared[bi_i, d_i * 4])
dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, d_i * 4],
acc_dkv_shared[bi_i, d_i * 4],
)
# Atomically update dKV, dKV_tail tensors
for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4):
T.atomic_addx4(
dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)],
bz, D + d_i * 4], acc_dkv_tail_shared[bi_i, d_i * 4])
dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, D + d_i * 4],
acc_dkv_tail_shared[bi_i, d_i * 4],
)
# Store the accumulated dQ
T.copy(acc_dq, dQ_shared)
T.copy(acc_dq_tail, dQ_tail_shared)
T.copy(dQ_shared, dQ[by, s_i, bz * padded_H:(bz + 1) * padded_H, :D])
T.copy(dQ_tail_shared, dQ[by, s_i, bz * padded_H:(bz + 1) * padded_H, D:])
T.copy(dQ_shared, dQ[by, s_i, bz * padded_H : (bz + 1) * padded_H, :D])
T.copy(dQ_tail_shared, dQ[by, s_i, bz * padded_H : (bz + 1) * padded_H, D:])
return sparse_mla_bwd_kernel
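The per-tile math the kernel implements matches the usual attention backward. A compact PyTorch sketch for one (batch, query position, kv-group) tile, assuming q is [H, D + D_tail], kv is [topk, D + D_tail], do is [H, D], and lse/delta are [H] with lse in the log2 domain as produced by the forward kernel (names and shapes are illustrative, not the kernel's API):

import torch

def ref_tile_bwd(q, kv, do, lse, delta, sm_scale):
    D = do.shape[-1]
    p = torch.exp2(q @ kv.T * sm_scale * 1.44269504 - lse[:, None])  # [H, topk] softmax probabilities
    dp = p * (do @ kv[:, :D].T - delta[:, None]) * sm_scale          # gradient w.r.t. the scores
    dq = dp @ kv                                                     # [H, D + D_tail]
    dkv = dp.T @ q                                                   # key part of the kv gradient
    dkv[:, :D] += p.T @ do                                           # value part only touches the first D dims
    return dq, dkv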
def sparse_mla_bwd(q,
kv,
o,
do,
indices,
lse,
sm_scale=None,
is_casual=True,
return_kernel=False,
delta=None):
def sparse_mla_bwd(q, kv, o, do, indices, lse, sm_scale=None, is_casual=True, return_kernel=False, delta=None):
assert q.is_contiguous()
assert kv.is_contiguous()
assert indices.is_contiguous()
......@@ -322,6 +281,7 @@ def sparse_mla_bwd(q,
def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, sm_scale=None, is_casual=True):
from sparse_mla_fwd import ref_sparse_mla_fwd_interface
q = q.detach().clone()
kv = kv.detach().clone()
q.requires_grad = True
......@@ -331,30 +291,22 @@ def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, sm_scale=None, is_c
return q.grad, kv.grad
def test_sparse_mla_bwd(B=1,
S=4096,
SKV=8192,
H=64,
HKV=1,
DQKV=576,
DV=512,
topk=2048,
dtype=torch.bfloat16,
check_correctness=True):
def test_sparse_mla_bwd(B=1, S=4096, SKV=8192, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16, check_correctness=True):
# Prepare data
q = torch.randn((B, S, H, DQKV), dtype=dtype, device='cuda').requires_grad_(True)
kv = torch.randn((B, SKV, HKV, DQKV), dtype=dtype, device='cuda').requires_grad_(True)
do = torch.randn((B, S, H, DV), dtype=dtype, device='cuda')
q = torch.randn((B, S, H, DQKV), dtype=dtype, device="cuda").requires_grad_(True)
kv = torch.randn((B, SKV, HKV, DQKV), dtype=dtype, device="cuda").requires_grad_(True)
do = torch.randn((B, S, H, DV), dtype=dtype, device="cuda")
indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device='cuda')
indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda")
for b in range(B):
for t in range(S):
for h in range(HKV):
i_i = torch.randperm(max(1, t))[:topk]
indices[b, t, h, :len(i_i)] = i_i
indices[b, t, h, : len(i_i)] = i_i
# Forward
from sparse_mla_fwd import sparse_mla_fwd_interface
tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices)
tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse)
......@@ -365,13 +317,15 @@ def test_sparse_mla_bwd(B=1,
assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv")
print("assert_tensors_similar passed")
per_token_flop = 2 * sum([
H * DV * topk,
H * DQKV * topk,
H * DQKV * topk,
H * DQKV * topk,
H * DV * topk,
])
per_token_flop = 2 * sum(
[
H * DV * topk,
H * DQKV * topk,
H * DQKV * topk,
H * DQKV * topk,
H * DV * topk,
]
)
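# The five terms appear to count the five per-token GEMMs of the backward pass:
# dO @ KV^T (DV), Q @ KV^T (DQKV), dP @ KV (DQKV), dP^T @ Q (DQKV), and P^T @ dO (DV),
# each at 2 FLOPs per multiply-accumulate (hence the leading factor of 2).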
from tilelang.profiler import do_bench
def fn():
......@@ -379,20 +333,9 @@ def test_sparse_mla_bwd(B=1,
ms = do_bench(fn, rep=100, warmup=250)
print(f"Average time: {ms:.3f} ms")
print(f'bwd io bandwidth = ',
(B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12)
print(f'bwd tflops = ', per_token_flop * S / (ms * 1e-3) / 1e12)
print(f"bwd io bandwidth = ", (B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12)
print(f"bwd tflops = ", per_token_flop * S / (ms * 1e-3) / 1e12)
if __name__ == "__main__":
test_sparse_mla_bwd(
B=1,
S=4096,
SKV=8192,
H=64,
HKV=1,
DQKV=576,
DV=512,
topk=2048,
dtype=torch.bfloat16,
check_correctness=True)
test_sparse_mla_bwd(B=1, S=4096, SKV=8192, H=64, HKV=1, DQKV=576, DV=512, topk=2048, dtype=torch.bfloat16, check_correctness=True)
......@@ -25,15 +25,12 @@ def sparse_mla_fwd(
num_stages=2,
threads=256,
):
assert dim == tilelang.math.next_power_of_2(
dim), f"haven't check padding correctness yet, dim={dim}"
assert tail_dim == tilelang.math.next_power_of_2(
tail_dim), f"haven't check padding correctness yet, dim={tail_dim}"
assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}"
assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}"
assert is_causal == True, "non-casual is not supported"
assert (topk %
block_I == 0), "otherwise will load some index=0 thus causing wrong kv to be loaded"
assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded"
if sm_scale is None:
sm_scale = (1.0 / (dim + tail_dim))**0.5 * 1.44269504 # log2(e)
sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e)
else:
sm_scale = sm_scale * 1.44269504 # log2(e)
......@@ -55,9 +52,9 @@ def sparse_mla_fwd(
H = head_kv
padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
if padded_H != H:
assert (
kv_group == 1
), "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)"
assert kv_group == 1, (
"here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)"
)
BI = block_I
NI = tilelang.cdiv(topk, block_I)
D = dim
......@@ -73,18 +70,17 @@ def sparse_mla_fwd(
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype), # type: ignore
KV: T.Tensor(kv_shape, dtype), # type: ignore
Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
Output: T.Tensor(o_shape, dtype), # type: ignore
Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore
Q: T.Tensor(q_shape, dtype), # type: ignore
KV: T.Tensor(kv_shape, dtype), # type: ignore
Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
Output: T.Tensor(o_shape, dtype), # type: ignore
Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore
):
with T.Kernel(
seq_len * REPLICATE_H, batch, kv_group, threads=threads) as (
bx,
by,
bz,
):
with T.Kernel(seq_len * REPLICATE_H, batch, kv_group, threads=threads) as (
bx,
by,
bz,
):
Q_shared = T.alloc_shared([H_per_block, D], dtype)
Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
KV_shared = T.alloc_shared([BI, D], dtype)
......@@ -118,16 +114,13 @@ def sparse_mla_fwd(
T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)
for i_i in T.Pipelined(NI, num_stages=num_stages):
for bi_i in T.Parallel(BI):
mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i
for bi_i, d_i in T.Parallel(BI, D):
KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i,
d_i]
KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i]
for bi_i, d_i in T.Parallel(BI, D_tail):
K_tail_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i,
D + d_i]
K_tail_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i]
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype))
......@@ -176,15 +169,7 @@ def sparse_mla_fwd(
return main
def sparse_mla_fwd_interface(q,
kv,
indices,
sm_scale=None,
return_p_sum: bool = False,
d_v=512,
block_I=64,
num_stages=2,
threads=256):
def sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, return_p_sum: bool = False, d_v=512, block_I=64, num_stages=2, threads=256):
is_casual = True
assert return_p_sum == False, "This kernel file is for fwd only"
assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
......@@ -201,16 +186,8 @@ def sparse_mla_fwd_interface(q,
assert indices.shape == (batch, seq_len, kv_group, topk)
kernel = sparse_mla_fwd(
heads,
dim,
tail_dim,
topk,
kv_group,
sm_scale,
is_casual,
block_I=block_I,
num_stages=num_stages,
threads=threads)
heads, dim, tail_dim, topk, kv_group, sm_scale, is_casual, block_I=block_I, num_stages=num_stages, threads=threads
)
out, lse = kernel(q, kv, indices)
return out, lse
......@@ -230,14 +207,14 @@ def ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, is_casual=True):
b, _, _, dim_v = v.shape
g_index = g
h_index = h // g
compressed_casual_mask = torch.arange(
0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange(
1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda").view(1, -1)
compressed_casual_mask = torch.arange(0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange(
1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda"
).view(1, -1)
mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1)
mask = mask[..., :-1]
mask = mask & compressed_casual_mask.view(1, 1, sq, sk)
mask[:, :, :1 - 1, 0] = True
mask[:, :, : 1 - 1, 0] = True
mask = mask.view(b, g_index, 1, sq, sk)
q = q.view(b, sq, g, -1, dim_q)
......@@ -252,19 +229,21 @@ def ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, is_casual=True):
return o.to(torch.bfloat16)
def test_sparse_mla_fwd(B=1,
S=4096,
SKV=8192,
H=128,
HKV=1,
DQK=576,
DV=512,
topk=2048,
dtype=torch.bfloat16,
check_correctness=True,
block_I=64,
num_stages=2,
threads=256):
def test_sparse_mla_fwd(
B=1,
S=4096,
SKV=8192,
H=128,
HKV=1,
DQK=576,
DV=512,
topk=2048,
dtype=torch.bfloat16,
check_correctness=True,
block_I=64,
num_stages=2,
threads=256,
):
torch.random.manual_seed(0)
q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True)
kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True)
......@@ -274,10 +253,9 @@ def test_sparse_mla_fwd(B=1,
for t in range(S):
for h in range(HKV):
i_i = torch.randperm(max(1, t))[:topk]
indices[b, t, h, :len(i_i)] = i_i
indices[b, t, h, : len(i_i)] = i_i
tl_out, tl_lse = sparse_mla_fwd_interface(
q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads)
tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads)
if check_correctness:
# otherwise may cause out of memory
......@@ -286,8 +264,7 @@ def test_sparse_mla_fwd(B=1,
print("assert_tensors_similar passed")
def fn():
return sparse_mla_fwd_interface(
q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads)
return sparse_mla_fwd_interface(q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads)
from tilelang.profiler import do_bench
......@@ -315,4 +292,5 @@ if __name__ == "__main__":
check_correctness=True,
block_I=64,
num_stages=2,
threads=256)
threads=256,
)
......@@ -9,10 +9,16 @@ import argparse
@tilelang.jit(
out_idx=[-2, -1],
compile_flags=[
"-O3", "-Wno-deprecated-declarations", "-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "--expt-relaxed-constexpr", "--expt-extended-lambda",
"--ptxas-options=-v,--register-usage-level=10", "-DNDEBUG"
"-O3",
"-Wno-deprecated-declarations",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--ptxas-options=-v,--register-usage-level=10",
"-DNDEBUG",
],
)
def sparse_mla_fwd(
......@@ -32,14 +38,12 @@ def sparse_mla_fwd(
num_stages=0,
threads=384,
):
assert dim == tilelang.math.next_power_of_2(
dim), f"haven't check padding correctness yet, dim={dim}"
assert tail_dim == tilelang.math.next_power_of_2(
tail_dim), f"haven't check padding correctness yet, dim={tail_dim}"
assert is_causal == True, 'non-casual is not supported'
assert topk % block_I == 0, 'otherwise will load some index=0 thus causing wrong kv to be loaded'
assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}"
assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}"
assert is_causal == True, "non-casual is not supported"
assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded"
if sm_scale is None:
sm_scale = (1.0 / (dim + tail_dim))**0.5 * 1.44269504 # log2(e)
sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e)
else:
sm_scale = sm_scale * 1.44269504 # log2(e)
......@@ -57,15 +61,17 @@ def sparse_mla_fwd(
H = head_kv
padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
if padded_H != H:
assert kv_group == 1, 'here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)'
assert kv_group == 1, (
"here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)"
)
BI = block_I
NI = tilelang.cdiv(topk, block_I)
assert NI % 2 == 0, 'NI should be a multiple of 2'
assert NI % 2 == 0, "NI should be a multiple of 2"
D = dim
D_tail = tail_dim
KV_stride = kv_stride
if head_kv > 64:
assert head_kv % 64 == 0, 'head_kv should be a multiple of 64'
assert head_kv % 64 == 0, "head_kv should be a multiple of 64"
REPLICATE_H = head_kv // 64
else:
REPLICATE_H = 1
......@@ -74,18 +80,14 @@ def sparse_mla_fwd(
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype), # type: ignore
KV: T.Tensor(kv_shape, dtype), # type: ignore
Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
q_start_index_s: T.Tensor(1, indices_dtype),
Output: T.Tensor(o_shape, dtype), # type: ignore
Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore
Q: T.Tensor(q_shape, dtype), # type: ignore
KV: T.Tensor(kv_shape, dtype), # type: ignore
Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
q_start_index_s: T.Tensor(1, indices_dtype),
Output: T.Tensor(o_shape, dtype), # type: ignore
Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore
):
with T.Kernel(
(seq_len - kv_stride + 1 if CP0 else seq_len) * REPLICATE_H,
batch,
kv_group,
threads=threads) as (bx, by, bz):
with T.Kernel((seq_len - kv_stride + 1 if CP0 else seq_len) * REPLICATE_H, batch, kv_group, threads=threads) as (bx, by, bz):
Q_shared_l = T.alloc_shared([H_per_block, D // 2], dtype)
Q_shared_r = T.alloc_shared([H_per_block, D // 2], dtype)
Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
......@@ -122,8 +124,7 @@ def sparse_mla_fwd(
bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256)
b_i, g_i = by, bz
s_i = (bx + (KV_stride - 1 if CP0 else 0)) if REPLICATE_H == 1 else (
bx // REPLICATE_H + (KV_stride - 1 if CP0 else 0))
s_i = (bx + (KV_stride - 1 if CP0 else 0)) if REPLICATE_H == 1 else (bx // REPLICATE_H + (KV_stride - 1 if CP0 else 0))
q_i = q_start_index_s[0] + s_i
max_kv_i = (q_i + 1 - KV_stride) // KV_stride
......@@ -132,26 +133,24 @@ def sparse_mla_fwd(
tx = T.get_thread_binding()
T.copy(Q[b_i, s_i, H0:H1, 0:D // 2], Q_shared_l)
T.copy(Q[b_i, s_i, H0:H1, D // 2:D], Q_shared_r)
T.copy(Q[b_i, s_i, H0:H1, 0 : D // 2], Q_shared_l)
T.copy(Q[b_i, s_i, H0:H1, D // 2 : D], Q_shared_r)
T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)
T.barrier_arrive(bar_q)
if tx < 128:
T.set_max_nreg(240, 1)
T.fill(sumexp, 0)
T.fill(m_i, -2**30) # avoid -inf - inf to cause nan
T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan
T.fill(acc_o_l, 0)
T.barrier_wait(bar_q, 0)
for i_i in T.serial(T.ceildiv(NI, 2)):
# Buffer 0
T.barrier_wait(bar_k_0_ready[0], (i_i & 1))
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0,
-T.infinity(acc_s.dtype))
acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1)
T.gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1)
T.gemm(Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True, wg_wait=-1)
......@@ -187,8 +186,7 @@ def sparse_mla_fwd(
T.barrier_wait(bar_k_1_ready[0], (i_i & 1))
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0,
-T.infinity(acc_s.dtype))
acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, -T.infinity(acc_s.dtype))
T.gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1)
T.gemm(Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1)
T.gemm(Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True, wg_wait=-1)
......@@ -227,7 +225,7 @@ def sparse_mla_fwd(
for h_i in T.Parallel(H_per_block):
sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale
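# Note: Lse is therefore stored in the log2 domain (log2 of the exp2-sum plus the running max times sm_scale),
# which matches the exp2(scores * scale - Lse) form consumed by the backward kernel.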
T.copy(acc_o_l, O_shared_l)
T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0:D // 2])
T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0 : D // 2])
elif tx >= 128 and tx < 256:
T.set_max_nreg(168, 1)
......@@ -257,7 +255,7 @@ def sparse_mla_fwd(
acc_o_r[h_i, d_i] /= sum_exp_shared[h_i]
T.copy(acc_o_r, O_shared_r)
T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2:D])
T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2 : D])
elif tx >= 256:
# producer
T.set_max_nreg(80, 0)
......@@ -265,70 +263,58 @@ def sparse_mla_fwd(
# Buffer 0
T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1))
for r in T.serial(4):
indices_local[0] = Indices[b_i, s_i, g_i,
(i_i * 2) * BI + r * 16 + (tx - 256) // 8]
indices_local[0] = Indices[b_i, s_i, g_i, (i_i * 2) * BI + r * 16 + (tx - 256) // 8]
is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i
if is_kv_valid[r * 16 + (tx - 256) // 8]:
with T.attr("default", "async_scope", 1):
for u in T.serial(4):
for v in T.vectorized(8):
KV_shared_0_l[r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 +
v] = KV[b_i, indices_local[0], g_i,
64 * u + (tx - 256) % 8 * 8 + v]
KV_shared_0_r[r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 +
v] = KV[b_i, indices_local[0], g_i, D // 2 +
64 * u + (tx - 256) % 8 * 8 + v]
KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[
b_i, indices_local[0], g_i, 64 * u + (tx - 256) % 8 * 8 + v
]
KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[
b_i, indices_local[0], g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v
]
with T.attr("default", "async_scope", 1):
for v in T.vectorized(8):
K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 +
v] = KV[b_i, indices_local[0], g_i,
D + (tx - 256) % 8 * 8 + v]
K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[
b_i, indices_local[0], g_i, D + (tx - 256) % 8 * 8 + v
]
T.cp_async_barrier_noinc(bar_k_0_ready[0])
# Buffer 1
T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1))
for r in T.serial(4):
indices_local[0] = Indices[b_i, s_i, g_i,
(i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8]
indices_local[0] = Indices[b_i, s_i, g_i, (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8]
is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i
if is_kv_valid[r * 16 + (tx - 256) // 8]:
with T.attr("default", "async_scope", 1):
for u in T.serial(4):
for v in T.vectorized(8):
KV_shared_1_l[r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 +
v] = KV[b_i, indices_local[0], g_i,
64 * u + (tx - 256) % 8 * 8 + v]
KV_shared_1_r[r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 +
v] = KV[b_i, indices_local[0], g_i, D // 2 +
64 * u + (tx - 256) % 8 * 8 + v]
KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[
b_i, indices_local[0], g_i, 64 * u + (tx - 256) % 8 * 8 + v
]
KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[
b_i, indices_local[0], g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v
]
with T.attr("default", "async_scope", 1):
for v in T.vectorized(8):
K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 +
v] = KV[b_i, indices_local[0], g_i,
D + (tx - 256) % 8 * 8 + v]
K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[
b_i, indices_local[0], g_i, D + (tx - 256) % 8 * 8 + v
]
T.cp_async_barrier_noinc(bar_k_1_ready[0])
return main
def sparse_mla_fwd_interface(q,
kv,
indices,
q_start_index_s,
kv_stride,
sm_scale=None,
is_casual=True,
return_kernel=False,
print_kernel=False):
def sparse_mla_fwd_interface(
q, kv, indices, q_start_index_s, kv_stride, sm_scale=None, is_casual=True, return_kernel=False, print_kernel=False
):
assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
batch, seq_len, heads, dim_plus_tail_dim = q.shape
_, seq_len_kv, kv_group, _ = kv.shape
assert dim_plus_tail_dim == 576, 'you should assign dim otherwise'
assert dim_plus_tail_dim == 576, "you should assign dim otherwise"
dim = 512
assert kv.shape[-1] == dim_plus_tail_dim
......@@ -338,29 +324,23 @@ def sparse_mla_fwd_interface(q,
assert indices.shape == (batch, seq_len, kv_group, topk)
if q_start_index_s != 0:
assert q_start_index_s > kv_stride, "If it is because each cp has too short length, you should fix the logic involving CP0 (cp_rank == 0), to make sure q with pos < KV_Stride - 1 is masked (or you may just ignore how this is handled if nan in these q's Out would not effect others, which is reported to be likely to happen by wangding)"
assert q_start_index_s > kv_stride, (
"If it is because each cp has too short length, you should fix the logic involving CP0 (cp_rank == 0), to make sure q with pos < KV_Stride - 1 is masked (or you may just ignore how this is handled if nan in these q's Out would not effect others, which is reported to be likely to happen by wangding)"
)
CP0 = q_start_index_s == 0
kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride,
kv_group, sm_scale, is_casual, CP0)
kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, kv_group, sm_scale, is_casual, CP0)
if print_kernel:
print(kernel.get_kernel_source())
out, lse = kernel(q, kv, indices,
torch.tensor([q_start_index_s], dtype=torch.int32, device="cuda"))
out, lse = kernel(q, kv, indices, torch.tensor([q_start_index_s], dtype=torch.int32, device="cuda"))
if return_kernel:
return kernel
if q_start_index_s == 0 and kv_stride > 1:
out[:, :kv_stride - 1, :, :] = 0
out[:, : kv_stride - 1, :, :] = 0
return out, lse
def ref_sparse_mla_fwd_interface(q,
kv,
indices,
q_start_index_s,
kv_stride=4,
sm_scale=None,
is_casual=True):
def ref_sparse_mla_fwd_interface(q, kv, indices, q_start_index_s, kv_stride=4, sm_scale=None, is_casual=True):
q = q.float()
kv = kv.float()
indices = indices.transpose(1, 2)
......@@ -369,7 +349,7 @@ def ref_sparse_mla_fwd_interface(q,
if q_start_index_s is None:
q_start_index_s = sk * kv_stride - sq
assert kv.shape[-1] == 576, 'you should assign dim otherwise'
assert kv.shape[-1] == 576, "you should assign dim otherwise"
dim = 512
k = kv
v = kv[..., :dim]
......@@ -378,15 +358,14 @@ def ref_sparse_mla_fwd_interface(q,
num_kv_per_index = 1
g_index = g
h_index = h // g
compressed_casual_mask = torch.arange(
q_start_index_s, sq + q_start_index_s, dtype=torch.int32,
device="cuda").view(-1, 1) >= torch.arange(
kv_stride - 1, sk * kv_stride, kv_stride, dtype=torch.int32, device="cuda").view(1, -1)
compressed_casual_mask = torch.arange(q_start_index_s, sq + q_start_index_s, dtype=torch.int32, device="cuda").view(
-1, 1
) >= torch.arange(kv_stride - 1, sk * kv_stride, kv_stride, dtype=torch.int32, device="cuda").view(1, -1)
mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1)
mask = mask[..., :-1]
mask = mask & compressed_casual_mask.view(1, 1, sq, sk)
mask[:, :, :kv_stride - 1, 0] = True
mask[:, :, : kv_stride - 1, 0] = True
mask = mask.view(b, g_index, 1, sq, sk)
q = q.view(b, sq, g, -1, dim_q)
......@@ -401,41 +380,32 @@ def ref_sparse_mla_fwd_interface(q,
return o.to(torch.bfloat16)
def test_sparse_mla_fwd_pipelined(B=1,
S=4096,
SKV=8192,
H=128,
HKV=1,
DQK=576,
DV=512,
topk=2048,
dtype=torch.bfloat16,
q_start_s_index=1024,
check_correctness=True):
def test_sparse_mla_fwd_pipelined(
B=1, S=4096, SKV=8192, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16, q_start_s_index=1024, check_correctness=True
):
KV_stride = 1
torch.random.manual_seed(0)
q = torch.randn((B, S, H, DQK), dtype=dtype, device='cuda').requires_grad_(True) / 10
kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device='cuda').requires_grad_(True) / 10
q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) / 10
kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) / 10
q_start_s_index_t = torch.tensor([q_start_s_index], dtype=torch.int32, device="cuda")
q.clamp_(-10, 10)
kv.clamp_(-10, 10)
indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device='cuda')
indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda")
for b in range(B):
for t in range(S):
for h in range(HKV):
i_i = torch.randperm(min(max(1, ((t + q_start_s_index) // KV_stride)), SKV))[:topk]
indices[b, t, h, :len(i_i)] = i_i
indices[b, t, h, : len(i_i)] = i_i
kernel = sparse_mla_fwd_interface(
q, kv, indices, q_start_s_index, KV_stride, return_kernel=True, print_kernel=True)
kernel = sparse_mla_fwd_interface(q, kv, indices, q_start_s_index, KV_stride, return_kernel=True, print_kernel=True)
def fn():
out, lse = kernel(q, kv, indices, q_start_s_index_t)
if q_start_s_index == 0 and KV_stride > 1:
out[:, :KV_stride - 1, :, :] = 0
out[:, : KV_stride - 1, :, :] = 0
return out, lse
tl_out, tl_lse = fn()
......@@ -446,14 +416,15 @@ def test_sparse_mla_fwd_pipelined(B=1,
torch.testing.assert_close(tl_out, ref_out, rtol=1e-3, atol=1e-3)
from tilelang.profiler import do_bench
ms = do_bench(
fn,
rep=10,
warmup=10,
)
print(f"Average time: {ms:.3f} ms")
print(f'fwd io bandwidth = ', (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12)
print(f'fwd tflops = ', (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12)
print(f"fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12)
print(f"fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12)
if __name__ == "__main__":
......@@ -464,5 +435,4 @@ if __name__ == "__main__":
B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 8192, 128, 1, 576, 512, 2048, torch.bfloat16
else:
B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 4096, 8192, 128, 1, 576, 512, 2048, torch.bfloat16
test_sparse_mla_fwd_pipelined(
B, S, SKV, H, HKV, DQK, DV, topk, dtype, check_correctness=args.test_correctness)
test_sparse_mla_fwd_pipelined(B, S, SKV, H, HKV, DQK, DV, topk, dtype, check_correctness=args.test_correctness)
......@@ -21,23 +21,20 @@ def test_example_fp8_lighting_indexer():
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_fwd():
# small shapes for testing
sparse_mla_fwd.test_sparse_mla_fwd(
S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
sparse_mla_fwd.test_sparse_mla_fwd(S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_fwd_pipelined():
# small shapes for testing
sparse_mla_fwd_pipelined.test_sparse_mla_fwd_pipelined(
S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
sparse_mla_fwd_pipelined.test_sparse_mla_fwd_pipelined(S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_bwd():
sparse_mla_bwd.test_sparse_mla_bwd(
S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
sparse_mla_bwd.test_sparse_mla_bwd(S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
if __name__ == "__main__":
......
......@@ -127,9 +127,9 @@ def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"):
l_num_input = s_num_input[r_idx]
for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)):
if s * BLOCK_SIZE + tx < l_num_input:
l_bin_id32 = T.Cast("int32", ((
convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >>
(24 - round * 8)) & 0xFF))
l_bin_id32 = T.Cast(
"int32", ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF)
)
T.atomic_add(s_histogram[l_bin_id32], 1)
T.sync_threads()
# cumsum
......@@ -156,23 +156,20 @@ def tl_topk_impl(topk, in_dtype="float32", out_dtype="int32"):
for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)):
T.sync_threads()
if s * BLOCK_SIZE + tx < l_num_input:
l_bin_id32 = T.Cast("int32", ((
convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >>
(24 - round * 8)) & 0xFF))
l_bin_id32 = T.Cast(
"int32", ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF)
)
if l_bin_id32 > l_threshold_bin_id:
pos = T.atomic_add(
s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos
pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos
index[bx, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx]
elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
if round == 3:
l_out_pos = T.atomic_add(
s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos
l_out_pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos
if l_out_pos < topk:
index[bx, l_out_pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx]
else:
pos = T.atomic_add(s_num_input[r_idx ^ 1], 1, return_prev=True)
s_input_idx[r_idx ^ 1, pos] = s_input_idx[r_idx,
s * BLOCK_SIZE + tx]
s_input_idx[r_idx ^ 1, pos] = s_input_idx[r_idx, s * BLOCK_SIZE + tx]
return tl_topk_kernel
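For intuition, the kernel above performs an 8-bit radix top-k select: each round histograms one byte of the uint32-converted keys, immediately emits every candidate whose byte falls in a bin above a threshold bin, and refines the threshold bin in the next round with a reduced budget. A compact PyTorch sketch of the same idea, assuming keys have already been mapped to non-negative int64 values where larger means higher priority (hypothetical helper, not part of this file):

import torch

def radix_topk_indices(keys: torch.Tensor, topk: int) -> torch.Tensor:
    cand = torch.arange(keys.numel(), device=keys.device)
    picked, remaining = [], topk
    for rnd in range(4):                                  # four rounds of 8-bit digits, MSB first
        digits = (keys[cand] >> (24 - rnd * 8)) & 0xFF
        hist = torch.bincount(digits, minlength=256)
        above = hist.flip(0).cumsum(0).flip(0)            # above[b] = #candidates with digit >= b
        thr = int((above >= remaining).nonzero().max())   # largest bin that still covers the budget
        picked.append(cand[digits > thr])                 # bins above the threshold are selected outright
        remaining -= int((digits > thr).sum())
        cand = cand[digits == thr]                        # the threshold bin is refined next round
    picked.append(cand[:remaining])                       # final tie-break inside the last threshold bin
    return torch.cat(picked)[:topk]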
......@@ -186,7 +183,6 @@ def tl_topk(input, starts, ends, topk):
def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048):
batch = 64
seq_len = 32 * 1024
topk = 2048
......@@ -212,8 +208,7 @@ def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048):
set_ref = set(ref_np)
set_trt = set(trt_np)
intersection = set_ref & set_trt
print("selected/all:", len(intersection), "/", len(set_ref), "=",
len(intersection) / len(set_ref))
print("selected/all:", len(intersection), "/", len(set_ref), "=", len(intersection) / len(set_ref))
# Performance test with CUDA events
......
......@@ -23,8 +23,7 @@ def _is_equal(a, b):
if isinstance(a, torch.Tensor):
return a is b
# Whitelist of types that are safe to compare by value for caching.
if isinstance(a, (int, float, str, bool, type(None))) and isinstance(
b, (int, float, str, bool, type(None))):
if isinstance(a, (int, float, str, bool, type(None))) and isinstance(b, (int, float, str, bool, type(None))):
return a == b
# For other types, we cannot guarantee a cheap and safe comparison, so we fail the cache check.
return False
......@@ -58,9 +57,11 @@ def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]
if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
# For Tensors, check for object identity. For other types, check for equality.
# Python caches small integers, so `is` works for them but not for large integers like 4096.
if all(_is_equal(a, b) for a, b in zip(args, last_args)) and \
set(kwargs.keys()) == set(last_kwargs.keys()) and \
all(_is_equal(v, last_kwargs[k]) for k, v in kwargs.items()):
if (
all(_is_equal(a, b) for a, b in zip(args, last_args))
and set(kwargs.keys()) == set(last_kwargs.keys())
and all(_is_equal(v, last_kwargs[k]) for k, v in kwargs.items())
):
return last_result
result = fn(*args, **kwargs)
......@@ -79,73 +80,68 @@ def cal_seq_idx_from_cu_seqlens(cu_seqlens: torch.LongTensor, seq_len: int):
@tensor_cache
def cal_seq_idx_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor,
seq_len: int) -> torch.IntTensor:
seq_idx_for_q = torch.full((seq_len,),
len(cu_seqlens_qs),
dtype=torch.int32,
device=cu_seqlens_qs.device)
def cal_seq_idx_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, seq_len: int) -> torch.IntTensor:
seq_idx_for_q = torch.full((seq_len,), len(cu_seqlens_qs), dtype=torch.int32, device=cu_seqlens_qs.device)
for i in range(len(cu_seqlens_qs)):
seq_idx_for_q[cu_seqlens_qs[i]:cu_seqlens_qe[i]] = i
seq_idx_for_q[cu_seqlens_qs[i] : cu_seqlens_qe[i]] = i
return seq_idx_for_q
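# Example (hypothetical values): cu_seqlens_qs = [0, 3], cu_seqlens_qe = [3, 5], seq_len = 6
# -> seq_idx_for_q == [0, 0, 0, 1, 1, 2]; positions outside every [qs, qe) range keep the
# fill value len(cu_seqlens_qs), which selects the sentinel entry in the gather-based callers below.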
@tensor_cache
def cal_cu_seqlen_ks_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor,
cu_seqlens_ks: torch.LongTensor, seq_len: int) -> torch.IntTensor:
def cal_cu_seqlen_ks_for_q(
cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, cu_seqlens_ks: torch.LongTensor, seq_len: int
) -> torch.IntTensor:
cu_seqlen_ks_for_each_q = torch.gather(
input=torch.cat([
cu_seqlens_ks,
torch.full((1,),
torch.iinfo(torch.int32).max,
dtype=torch.int32,
device=cu_seqlens_qs.device)
]),
input=torch.cat([cu_seqlens_ks, torch.full((1,), torch.iinfo(torch.int32).max, dtype=torch.int32, device=cu_seqlens_qs.device)]),
dim=0,
index=cal_seq_idx_for_q(
cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long())
index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long(),
)
return cu_seqlen_ks_for_each_q.int()
@tensor_cache
def cal_cu_seqlen_ke_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor,
cu_seqlens_ks: torch.LongTensor, cu_seqlens_ke: torch.LongTensor,
q_start_idxs: torch.LongTensor, seq_len: int,
kv_stride: int) -> torch.IntTensor:
def cal_cu_seqlen_ke_for_q(
cu_seqlens_qs: torch.LongTensor,
cu_seqlens_qe: torch.LongTensor,
cu_seqlens_ks: torch.LongTensor,
cu_seqlens_ke: torch.LongTensor,
q_start_idxs: torch.LongTensor,
seq_len: int,
kv_stride: int,
) -> torch.IntTensor:
cu_seqlen_ke_for_each_q = torch.gather(
input=torch.cat(
[cu_seqlens_ke,
torch.zeros(1, dtype=torch.int32, device=cu_seqlens_qs.device)]),
input=torch.cat([cu_seqlens_ke, torch.zeros(1, dtype=torch.int32, device=cu_seqlens_qs.device)]),
dim=0,
index=cal_seq_idx_for_q(
cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long())
casual_cu_seqlen_ke_for_each_q = torch.zeros((seq_len,),
dtype=torch.int32,
device=cu_seqlens_qs.device)
index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long(),
)
casual_cu_seqlen_ke_for_each_q = torch.zeros((seq_len,), dtype=torch.int32, device=cu_seqlens_qs.device)
for i in range(len(cu_seqlens_qs)):
casual_cu_seqlen_ke_for_each_q[cu_seqlens_qs[i]:cu_seqlens_qe[i]] = (torch.arange(
q_start_idxs[i],
q_start_idxs[i] + cu_seqlens_qe[i] - cu_seqlens_qs[i],
dtype=torch.int32,
device=cu_seqlens_qs.device) + 1) // kv_stride + cu_seqlens_ks[i]
casual_cu_seqlen_ke_for_each_q[cu_seqlens_qs[i] : cu_seqlens_qe[i]] = (
torch.arange(
q_start_idxs[i], q_start_idxs[i] + cu_seqlens_qe[i] - cu_seqlens_qs[i], dtype=torch.int32, device=cu_seqlens_qs.device
)
+ 1
) // kv_stride + cu_seqlens_ks[i]
cu_seqlen_ke_for_each_q = torch.minimum(casual_cu_seqlen_ke_for_each_q, cu_seqlen_ke_for_each_q)
return cu_seqlen_ke_for_each_q.int()
@tensor_cache
def cal_ks_ke_from_cu_seqlen_qk(cu_seqlens_q: torch.LongTensor,
cu_seqlens_k: torch.LongTensor = None,
offs_q: torch.LongTensor = None,
*,
seq_len: int,
kv_stride: int = 1,
cp_rank: int = 0,
cp_size: int = 1,
balanced_cp=False):
'''
def cal_ks_ke_from_cu_seqlen_qk(
cu_seqlens_q: torch.LongTensor,
cu_seqlens_k: torch.LongTensor = None,
offs_q: torch.LongTensor = None,
*,
seq_len: int,
kv_stride: int = 1,
cp_rank: int = 0,
cp_size: int = 1,
balanced_cp=False,
):
"""
seq_len: seq len per cp rank
balanced cp slice assignment: 0 1 2 3 3 2 1 0
'''
"""
n_seq = len(cu_seqlens_q) - 1
assert n_seq > 0
assert cu_seqlens_q.shape == (n_seq + 1,)
......@@ -170,10 +166,12 @@ def cal_ks_ke_from_cu_seqlen_qk(cu_seqlens_q: torch.LongTensor,
def f(x: torch.Tensor):
chunks = x.chunk(cp_size * 2)
return torch.cat([
chunks[cp_rank],
chunks[cp_size - cp_rank - 1],
])
return torch.cat(
[
chunks[cp_rank],
chunks[cp_size - cp_rank - 1],
]
)
ks = f(ks)
ke = f(ke)
......@@ -189,8 +187,7 @@ def ceil_to_ue8m0(x: torch.Tensor):
return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
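# i.e. round the magnitude up to the next power of two (UE8M0 is an exponent-only 8-bit scale format),
# e.g. ceil_to_ue8m0(torch.tensor(3.0)) -> 4.0 and ceil_to_ue8m0(torch.tensor(0.3)) -> 0.5.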
def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int],
use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int], use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
sf = x_amax / 448.0
......@@ -239,14 +236,18 @@ def generate_random_cu_seqlens(per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1,
total_seqlen - (cp_rank + 1) * per_chunk_seqlen,
total_seqlen - cp_rank * per_chunk_seqlen,
)
ks = torch.cat([
cu_seqlens_ks_for_each_q[slice_short],
cu_seqlens_ks_for_each_q[slice_long],
])
ke = torch.cat([
cu_seqlens_ke_for_each_q[slice_short],
cu_seqlens_ke_for_each_q[slice_long],
])
ks = torch.cat(
[
cu_seqlens_ks_for_each_q[slice_short],
cu_seqlens_ks_for_each_q[slice_long],
]
)
ke = torch.cat(
[
cu_seqlens_ke_for_each_q[slice_short],
cu_seqlens_ke_for_each_q[slice_long],
]
)
assert len(ks) == len(ke) == per_cp_seqlen
return ks, ke
......@@ -302,11 +303,9 @@ def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True):
raise_assert: Whether to raise assertion error on failure
"""
sim = calculate_tensor_similarity(x, y, name)
diff = 1. - sim
diff = 1.0 - sim
if not (0 <= diff <= eps):
print(
f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m"
)
print(f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m")
if raise_assert:
assert False # noqa: B011
......@@ -316,11 +315,8 @@ if __name__ == "__main__":
cu_seqlens = torch.randint(128, 4096, (1000,), dtype=torch.int32, device="cuda")
last_idx = torch.where(cu_seqlens.cumsum(dim=0) >= seq_len)[0][0]
cu_seqlens_cumsum = cu_seqlens[:last_idx].cumsum(dim=0)
cu_seqlens_qs = torch.cat(
[torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device), cu_seqlens_cumsum])
cu_seqlens_qe = torch.cat(
[cu_seqlens_cumsum,
torch.ones(1, dtype=torch.int32, device=cu_seqlens.device) * seq_len])
cu_seqlens_qs = torch.cat([torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device), cu_seqlens_cumsum])
cu_seqlens_qe = torch.cat([cu_seqlens_cumsum, torch.ones(1, dtype=torch.int32, device=cu_seqlens.device) * seq_len])
from tilelang.profiler import do_bench
......
......@@ -39,12 +39,10 @@ def torch_convert_bit_twiddling(tensor):
res0 = val_concat_expanded & mask
res1 = (val_concat_expanded << 3) & mask
res2 = (val_concat_expanded << 6) & mask
res3 = ((val_concat_expanded << 1) & mask1) | ((val_concat_expanded >> 3) & mask2) | (
(val_concat_expanded >> 7) & mask3)
res3 = ((val_concat_expanded << 1) & mask1) | ((val_concat_expanded >> 3) & mask2) | ((val_concat_expanded >> 7) & mask3)
# Select the correct result based on position
bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1,
torch.where(pos == 2, res2, res3)))
bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1, torch.where(pos == 2, res2, res3)))
# Convert to uint16 for .view(torch.bfloat16)
bf16_uint16 = (bf16 & 0xFFFF).to(torch.uint16)
......@@ -110,7 +108,7 @@ def print_bit(name, val):
val (torch.Tensor): A scalar PyTorch tensor (numeric) whose 32-bit binary representation will be shown.
"""
val_cpu = val.cpu().item()
binary_repr = f'{val_cpu:032b}'
binary_repr = f"{val_cpu:032b}"
print(name, binary_repr)
......@@ -122,7 +120,7 @@ def calc_sim(x, y, name="tensor"):
x, y = x.data.double(), y.data.double()
denominator = (x * x + y * y).sum()
if denominator == 0:
print_red_warning(f'{name} all zero')
print_red_warning(f"{name} all zero")
return 1
sim = 2 * (x * y).sum() / denominator
return sim
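# The metric is sim(x, y) = 2 * sum(x * y) / sum(x * x + y * y), so identical tensors give sim == 1
# and assert_similar below checks diff = 1 - sim against eps; e.g.
# calc_sim(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 2.0])) evaluates to 1.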
......@@ -132,21 +130,19 @@ def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
x_mask = torch.isfinite(x)
y_mask = torch.isfinite(y)
if not torch.all(x_mask == y_mask):
print_red_warning(f'{name} Error: isfinite mask mismatch')
print_red_warning(f"{name} Error: isfinite mask mismatch")
if raise_assert:
raise AssertionError
if not torch.isclose(
x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0,
equal_nan=True).all():
print_red_warning(f'{name} Error: nonfinite value mismatch')
if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all():
print_red_warning(f"{name} Error: nonfinite value mismatch")
if raise_assert:
raise AssertionError
x = x.masked_fill(~x_mask, 0)
y = y.masked_fill(~y_mask, 0)
sim = calc_sim(x, y, name)
diff = (1. - sim).item()
print(f'{diff=}')
diff = (1.0 - sim).item()
print(f"{diff=}")
if not (0 <= diff <= eps):
print_red_warning(f'{name} Error: {diff=}')
print_red_warning(f"{name} Error: {diff=}")
if raise_assert:
raise AssertionError
......@@ -24,6 +24,7 @@ def get_configs():
the parameter name to its chosen value.
"""
import itertools
iter_params = dict(
block_M=[64, 128, 256],
block_N=[64, 128, 256],
......@@ -32,63 +33,62 @@ def get_configs():
threads=[128, 256, 512],
split=[1, 2],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
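# Each returned element is one autotuning candidate, e.g. (hypothetical entry)
# {"block_M": 64, "block_N": 64, ..., "threads": 128, "split": 1};
# the autotuner sweeps the full Cartesian product of iter_params.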
@tilelang.autotune(configs=get_configs(),)
@tilelang.autotune(
configs=get_configs(),
)
@tilelang.jit(
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
},
pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True},
)
def matmul(M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
source_format='uint',
num_bits=4,
fast_dequant=True,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1):
def matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
source_format="uint",
num_bits=4,
fast_dequant=True,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1,
):
"""
Builds a parameterized TileLang/TIR matrix-multiplication kernel that dequantizes 4-bit FP inputs to BF16 on-the-fly and computes C = A @ B^T.
This function returns a tiled, autotunable prim_func implementing a block-wise GEMM with shared-memory buffering and a pipelined K-loop. The kernel accepts:
- A: dense input of shape (M, K) with dtype `in_dtype`.
- B: packed quantized input of shape (N, QK) where QK = K / (8 / num_bits) stored as `uint8`.
- C: output of shape (M, N) with dtype `out_dtype`.
The generated kernel supports two dequantization paths:
- fast_dequant (fast_dequant=True): calls an external mxfp dequantization intrinsic (twiddling-based) loaded from a C source returned by get_mxfp_intrin_group.
- simple dequant (fast_dequant=False): performs a pure-TIR FP4 -> BF16 conversion per element.
Important behavior and requirements:
- num_bits (default 4) is the bit-width of the quantized elements; storage_dtype is uint8 and num_elems_per_byte = 8 // num_bits.
- QK = K // num_elems_per_byte and Block_QK = block_K // num_elems_per_byte determine B and shared-buffer shapes.
- Asserts that K % (block_K * split) == 0; K must be divisible by block_K * split for the tiling to be valid.
- When fast_dequant is True, a valid mxfp intrinsic group (C source and function name) must be available via tilelang.quantize.get_mxfp_intrin_group.
- The kernel launches a 2D grid over ceildiv(N, block_N) and ceildiv(M, block_M) and uses `threads` threads per block with `num_stages` pipeline stages.
Parameters that alter kernel layout/behavior (brief):
- block_M, block_N, block_K: tile sizes for M, N, and K dimensions.
- num_stages: number of software pipeline stages for the K-loop.
- threads: number of threads used per kernel block.
- split: extra K-splitting factor; K must be divisible by block_K * split.
- source_format, num_bits: describe the quantized data layout passed to the mxfp intrinsics.
Returns:
A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel.
"""
Builds a parameterized TileLang/TIR matrix-multiplication kernel that dequantizes 4-bit FP inputs to BF16 on-the-fly and computes C = A @ B^T.
This function returns a tiled, autotunable prim_func implementing a block-wise GEMM with shared-memory buffering and a pipelined K-loop. The kernel accepts:
- A: dense input of shape (M, K) with dtype `in_dtype`.
- B: packed quantized input of shape (N, QK) where QK = K / (8 / num_bits) stored as `uint8`.
- C: output of shape (M, N) with dtype `out_dtype`.
The generated kernel supports two dequantization paths:
- fast_dequant (fast_dequant=True): calls an external mxfp dequantization intrinsic (twiddling-based) loaded from a C source returned by get_mxfp_intrin_group.
- simple dequant (fast_dequant=False): performs a pure-TIR FP4 -> BF16 conversion per element.
Important behavior and requirements:
- num_bits (default 4) is the bit-width of the quantized elements; storage_dtype is uint8 and num_elems_per_byte = 8 // num_bits.
- QK = K // num_elems_per_byte and Block_QK = block_K // num_elems_per_byte determine B and shared-buffer shapes.
- Asserts that K % (block_K * split) == 0; K must be divisible by block_K * split for the tiling to be valid.
- When fast_dequant is True, a valid mxfp intrinsic group (C source and function name) must be available via tilelang.quantize.get_mxfp_intrin_group.
- The kernel launches a 2D grid over ceildiv(N, block_N) and ceildiv(M, block_M) and uses `threads` threads per block with `num_stages` pipeline stages.
Parameters that alter kernel layout/behavior (brief):
- block_M, block_N, block_K: tile sizes for M, N, and K dimensions.
- num_stages: number of software pipeline stages for the K-loop.
- threads: number of threads used per kernel block.
- split: extra K-splitting factor; K must be divisible by block_K * split.
- source_format, num_bits: describe the quantized data layout passed to the mxfp intrinsics.
Returns:
A TileLang/TIR prim_func (the compiled `main`) implementing the described dequantize-then-GEMM kernel.
"""
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
......@@ -189,8 +189,7 @@ def matmul(M,
# Finally, store the dequantized data to shared memory.
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
B_dequantize_shared[index // block_K,
index % block_K] = B_dequantize_local_thread[v]
B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v]
return fast_dequant_bf16_fp4_twiddling
......@@ -215,30 +214,29 @@ def matmul(M,
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr,
scale: tir.PrimExpr, dtype: str):
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str):
"""
Convert a 4-bit FP4 value packed in a uint8 byte into a bfloat16 value.
This helper extracts the 4-bit field located at the bit position `pos` within the
byte `val`, interprets it as an FP4 (sign, exponent, mantissa) value, applies an
exponent `scale` offset to align it with bfloat16 exponent bias, clamps the
resulting exponent to 8 bits, and returns the assembled bfloat16 bit pattern.
Parameters:
nbit (int): Number of bits in the packed element; must be 4.
val (tir.PrimExpr): A uint8 value containing packed FP4 elements.
pos (tir.PrimExpr): Index (0-based) of which FP4 nibble inside `val` to extract.
scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16.
dtype (str): Target dtype string; must be "bfloat16".
Returns:
tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value.
Notes:
- The function asserts `nbit == 4`, `dtype == "bfloat16"`, and that `val.dtype` is "uint8".
- The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16
bit fields and clamps the computed exponent to fit into 8 bits.
Convert a 4-bit FP4 value packed in a uint8 byte into a bfloat16 value.
This helper extracts the 4-bit field located at the bit position `pos` within the
byte `val`, interprets it as an FP4 (sign, exponent, mantissa) value, applies an
exponent `scale` offset to align it with bfloat16 exponent bias, clamps the
resulting exponent to 8 bits, and returns the assembled bfloat16 bit pattern.
Parameters:
nbit (int): Number of bits in the packed element; must be 4.
val (tir.PrimExpr): A uint8 value containing packed FP4 elements.
pos (tir.PrimExpr): Index (0-based) of which FP4 nibble inside `val` to extract.
scale (tir.PrimExpr): Exponent offset applied when converting FP4 exponent to bfloat16.
dtype (str): Target dtype string; must be "bfloat16".
Returns:
tir.PrimExpr: A bfloat16-typed PrimExpr containing the converted value.
Notes:
- The function asserts `nbit == 4`, `dtype == "bfloat16"`, and that `val.dtype` is "uint8".
- The conversion uses a fixed mapping from FP4 exponent/mantissa layout into bfloat16
bit fields and clamps the computed exponent to fit into 8 bits.
"""
assert nbit == 4
assert dtype == "bfloat16"
......@@ -254,8 +252,9 @@ def matmul(M,
e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
m_f4 = f4 & tir.const(1, "uint16")
val_bf16 = tir.reinterpret(
"bfloat16", ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16"))
| (m_f4 << tir.const(6, "uint16"))).astype("uint16"))
"bfloat16",
((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) | (m_f4 << tir.const(6, "uint16"))).astype("uint16"),
)
return val_bf16
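# A small sketch of the final bit-assembly step visible above, on plain Python integers. The
# helper name is ours and the field values are illustrative; it only mirrors the uint16 packing
# [1-bit sign | 8-bit exponent | 7-bit mantissa] and decodes it by padding to a float32.
import struct

def _assemble_bf16(s, e_bf16, m_f4):  # hypothetical helper, not part of the codebase
    bits16 = (((s << 8) | e_bf16) << 7) | (m_f4 << 6)
    # bfloat16 occupies the top half of an IEEE float32, so pad with 16 zero bits to decode
    return struct.unpack(">f", (bits16 << 16).to_bytes(4, "big"))[0]

assert _assemble_bf16(0, 127, 1) == 1.5   # exponent field 127 -> 2**0, mantissa MSB adds 0.5
assert _assemble_bf16(1, 128, 0) == -2.0  # sign bit set, exponent field 128 -> 2**1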
@T.macro
......@@ -292,32 +291,32 @@ def matmul(M,
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
):
"""
Kernel entry for the tiled, pipelined matmul used by the generated prim_func.
This function implements a block-wise GEMM over a 2D grid (grid dims: ceildiv(N, block_N) x ceildiv(M, block_M)) with a thread block of `threads`. For each output block it:
- Allocates shared buffers for A, the packed/quantized B, and a dequantized B tile.
- Allocates a fragment accumulator (C_local) and a shared output tile (C_shared) with a swizzled layout.
- Pipelines over K in chunks of `block_K` for `num_stages` stages:
- Loads A and packed B tiles into shared memory.
- Dequantizes B into B_dequantize_shared using either the fast (twiddling/external) or the simple (pure-TIR) dequantization routine.
- Performs a GEMM accumulating into C_local with B transposed.
- Stores the accumulated block from C_local back to the global output C via C_shared.
Parameters:
- A: input tile of shape (M, K) with dtype `in_dtype`.
- B: packed/quantized input of shape (N, QK) with storage dtype `storage_dtype` (quantized FP4 packing).
- C: output tensor of shape (M, N) with dtype `out_dtype`.
Side effects:
- Writes the computed output block into the global tensor `C`.
- Uses and updates shared memory buffers and per-thread accumulators.
No value is returned.
Kernel entry for the tiled, pipelined matmul used by the generated prim_func.
This function implements a block-wise GEMM over a 2D grid (grid dims: ceildiv(N, block_N) x ceildiv(M, block_M)) with a thread block of `threads`. For each output block it:
- Allocates shared buffers for A, the packed/quantized B, and a dequantized B tile.
- Allocates a fragment accumulator (C_local) and a shared output tile (C_shared) with a swizzled layout.
- Pipelines over K in chunks of `block_K` for `num_stages` stages:
- Loads A and packed B tiles into shared memory.
- Dequantizes B into B_dequantize_shared using either the fast (twiddling/external) or the simple (pure-TIR) dequantization routine.
- Performs a GEMM accumulating into C_local with B transposed.
- Stores the accumulated block from C_local back to the global output C via C_shared.
Parameters:
- A: input tile of shape (M, K) with dtype `in_dtype`.
- B: packed/quantized input of shape (N, QK) with storage dtype `storage_dtype` (quantized FP4 packing).
- C: output tensor of shape (M, N) with dtype `out_dtype`.
Side effects:
- Writes the computed output block into the global tensor `C`.
- Uses and updates shared memory buffers and per-thread accumulators.
No value is returned.
"""
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -327,9 +326,11 @@ def matmul(M,
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
T.annotate_layout({
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
})
T.annotate_layout(
{
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
}
)
T.clear(C_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
......@@ -344,7 +345,7 @@ def matmul(M,
T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N])
T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N])
return main
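# A compact NumPy sketch of the tiling scheme the kernel above implements: the output is built
# block by block over a ceildiv(N, block_N) x ceildiv(M, block_M) grid, accumulating partial
# products over K in block_K chunks with B transposed. Shapes, tile sizes, and the helper name
# are illustrative assumptions, not the kernel defaults.
import numpy as np

def _tiled_matmul(A, B, block_M=32, block_N=32, block_K=16):  # hypothetical reference helper
    M, K = A.shape
    N = B.shape[0]  # B is (N, K); the kernel multiplies with transpose_B=True
    C = np.zeros((M, N), dtype=A.dtype)
    for by in range(M // block_M):
        for bx in range(N // block_N):
            acc = np.zeros((block_M, block_N), dtype=A.dtype)
            for k in range(K // block_K):
                a = A[by * block_M:(by + 1) * block_M, k * block_K:(k + 1) * block_K]
                b = B[bx * block_N:(bx + 1) * block_N, k * block_K:(k + 1) * block_K]
                acc += a @ b.T  # same accumulation as T.gemm(..., transpose_B=True)
            C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N] = acc
    return C

_A = np.random.rand(64, 64).astype(np.float32)
_B = np.random.rand(64, 64).astype(np.float32)
assert np.allclose(_tiled_matmul(_A, _B), _A @ _B.T, atol=1e-4)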
......@@ -409,8 +410,7 @@ def main(m=256, n=256, k=256, fast_dequant=True, tune=False):
"""
total_flops = 2 * m * n * k
if tune:
kernel = matmul(
m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, fast_dequant=fast_dequant)
kernel = matmul(m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, fast_dequant=fast_dequant)
else:
kernel = matmul(
m,
......@@ -426,7 +426,8 @@ def main(m=256, n=256, k=256, fast_dequant=True, tune=False):
block_K=128,
num_stages=2,
threads=256,
split=1)
split=1,
)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto)
if fast_dequant:
profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01)
......
......@@ -7,29 +7,28 @@ import torch
from dequantize_utils import torch_convert_bit_twiddling, torch_convert
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr,
dtype: str):
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str):
"""
Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale.
Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale.
This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its
bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by
`scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation.
This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its
bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by
`scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation.
Parameters:
nbit (int): Number of bits in the packed field (must be 4).
val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields.
pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field).
scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like).
dtype (str): Destination dtype string (must be "bfloat16").
Parameters:
nbit (int): Number of bits in the packed field (must be 4).
val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields.
pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field).
scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like).
dtype (str): Destination dtype string (must be "bfloat16").
Returns:
tir.PrimExpr: The resulting value reinterpreted as `bfloat16`.
Returns:
tir.PrimExpr: The resulting value reinterpreted as `bfloat16`.
Notes:
- Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8".
- The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern.
"""
Notes:
- Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8".
- The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern.
"""
assert nbit == 4
assert dtype == "bfloat16"
assert val.dtype == "uint8"
......@@ -43,9 +42,10 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
# To handle overflow, we may use the min function to clamp the exponent to 8 bits

# e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
m_f4 = f4 & tir.const(1, "uint16")
val_bf16 = tir.reinterpret("bfloat16",
((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16"))
| (m_f4 << tir.const(6, "uint16"))).astype("uint16"))
val_bf16 = tir.reinterpret(
"bfloat16",
((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) | (m_f4 << tir.const(6, "uint16"))).astype("uint16"),
)
return val_bf16
......@@ -65,6 +65,7 @@ def get_configs():
List[dict]: A list of configuration dictionaries covering all combinations.
"""
import itertools
iter_params = dict(
block_M=[64, 128, 256],
block_N=[64, 128, 256],
......@@ -73,67 +74,71 @@ def get_configs():
threads=[128, 256, 512],
split=[1, 2],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
@tilelang.autotune(configs=get_configs(),)
@tilelang.jit(out_idx=[-1],)
def matmul(M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
source_format='uint',
num_bits=4,
scale_size=32,
fast_dequant=True,
with_bias=False,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1):
return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
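# A tiny sketch of the config-generation pattern above: zipping the parameter names with one
# tuple from itertools.product yields one dict per combination, so the sweep size is the
# product of the option counts. The parameter lists below are illustrative, not the real
# search space.
import itertools

_params = dict(block_M=[64, 128], block_N=[64, 128], threads=[128, 256])
_configs = [dict(zip(_params, values)) for values in itertools.product(*_params.values())]
assert len(_configs) == 2 * 2 * 2
assert _configs[0] == {"block_M": 64, "block_N": 64, "threads": 128}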
@tilelang.autotune(
configs=get_configs(),
)
@tilelang.jit(
out_idx=[-1],
)
def matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
source_format="uint",
num_bits=4,
scale_size=32,
fast_dequant=True,
with_bias=False,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1,
):
"""
Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype.
The generated kernel accepts:
- A: dense matrix with element type `in_dtype`.
- B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)).
- Scale: per-block scale/exponent information used to dequantize B.
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (False): uses a simple elementwise dequantization helper.
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
out_dtype (str): output tensor element type (e.g., "bfloat16").
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
scale_size (int, optional): number of elements grouped per scale entry (default 32).
fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True).
block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128).
num_stages (int, optional): pipelining stages for K loop (default 2).
threads (int, optional): threads per block used by the kernel (default 256).
split (int, optional): split factor along K used by the scheduler (default 1).
with_bias (bool, optional): whether to add Bias to the output (default False).
Returns:
A T.prim_func implementing the tiled, pipelined GEMM that:
- loads tiled blocks of A and packed B to shared memory,
- dequantizes B via the chosen path into a shared dequantized tile,
- performs a tiled GEMM accumulating into local fragments,
- writes the final MxN block to the global output tensor.
Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype.
Notes:
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile.
- An assertion enforces that K % (block_K * split) == 0.
The generated kernel accepts:
- A: dense matrix with element type `in_dtype`.
- B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)).
- Scale: per-block scale/exponent information used to dequantize B.
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (False): uses a simple elementwise dequantization helper.
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
out_dtype (str): output tensor element type (e.g., "bfloat16").
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
scale_size (int, optional): number of elements grouped per scale entry (default 32).
fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True).
block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128).
num_stages (int, optional): pipelining stages for K loop (default 2).
threads (int, optional): threads per block used by the kernel (default 256).
split (int, optional): split factor along K used by the scheduler (default 1).
with_bias (bool, optional): whether to add Bias to the output (default False).
Returns:
A T.prim_func implementing the tiled, pipelined GEMM that:
- loads tiled blocks of A and packed B to shared memory,
- dequantizes B via the chosen path into a shared dequantized tile,
- performs a tiled GEMM accumulating into local fragments,
- writes the final MxN block to the global output tensor.
Notes:
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile.
- An assertion enforces that K % (block_K * split) == 0.
"""
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
......@@ -150,6 +155,7 @@ def matmul(M,
assert K % (block_K * split) == 0
from tilelang.quantize import get_mxfp_intrin_group
# fast_dequant_bf16_fp4_twiddling
mxfp_intrin_info = get_mxfp_intrin_group(
out_dtype=in_dtype,
......@@ -252,8 +258,7 @@ def matmul(M,
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
B_dequantize_shared[index // block_K,
index % block_K] = B_dequantize_local_thread[v]
B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v]
return fast_dequant_bf16_fp4_twiddling
......@@ -301,33 +306,32 @@ def matmul(M,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
Scale[
bx * block_N + i, k * block_K // scale_size + j //
scale_size], # Scale is the exponential part, within the representation of uint8
bx * block_N + i, k * block_K // scale_size + j // scale_size
], # Scale is the exponent, stored as a uint8
dtype=out_dtype,
) * T.shift_left(
1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size]))
) * T.shift_left(1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size]))
T.copy(B_dequantize_local, B_dequantize_shared)
return simple_dequant_bf16_fp4
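# A minimal sketch of the per-group scaling applied above, with illustrative values: element j
# of a row uses the exponent at position j // scale_size, and T.shift_left(1, s) equals 2**s for
# these uint8 exponents, so every group of scale_size dequantized values shares one power of two.
_scale_size = 4
_scales = [0, 3]  # one uint8 exponent per group of 4 elements
_scaled = [1.0 * (1 << _scales[j // _scale_size]) for j in range(8)]
assert _scaled == [1.0, 1.0, 1.0, 1.0, 8.0, 8.0, 8.0, 8.0]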
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Scale: T.Tensor(Scale_shape, storage_dtype),
Bias: T.Tensor(Bias_shape, out_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Scale: T.Tensor(Scale_shape, storage_dtype),
Bias: T.Tensor(Bias_shape, out_dtype),
C: T.Tensor((M, N), out_dtype),
):
"""
Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C.
Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C.
This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function.
This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function.
Parameters are self-descriptive in the signature; notable behaviors:
- B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM.
- The selected dequantization path is controlled by the outer-scope flag `fast_dequant`.
- The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization).
- The function writes results in-place into C.
Parameters are self-descriptive in the signature; notable behaviors:
- B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM.
- The selected dequantization path is controlled by the outer-scope flag `fast_dequant`.
- The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization).
- The function writes results in-place into C.
"""
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -337,23 +341,26 @@ def matmul(M,
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
T.annotate_layout({
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
})
T.annotate_layout(
{
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
}
)
if with_bias:
T.annotate_layout({
Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared),
})
T.annotate_layout(
{
Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared),
}
)
if threads == 512:
T.disable_warp_group_reg_alloc()
if with_bias:
T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N],
Bias_shared)
T.copy(Bias[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N], Bias_shared)
T.copy(Bias_shared, C_local)
else:
T.clear(C_local)
......@@ -368,7 +375,7 @@ def matmul(M,
T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N])
T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N])
return main
......@@ -389,7 +396,7 @@ def ref_program_twiddling(A, qB, Scale, Bias=None):
"""
dtypeC = "bfloat16"
B = torch_convert_bit_twiddling(qB)
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C
......@@ -412,7 +419,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias):
"""
dtypeC = "bfloat16"
B = torch_convert_bit_twiddling(qB)
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
C = C.to(torch.__getattribute__(dtypeC))
return C
......@@ -436,7 +443,7 @@ def ref_program_simple(A, qB, Scale, Bias=None):
"""
dtypeC = "bfloat16"
B = torch_convert(qB)
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C
......@@ -464,7 +471,7 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias):
"""
dtypeC = "bfloat16"
B = torch_convert(qB)
B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
B *= 2 ** (Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
C = C.to(torch.__getattribute__(dtypeC))
return C
......@@ -491,16 +498,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
if tune:
kernel = matmul(
m,
n,
k,
"bfloat16",
"bfloat16",
"float32",
num_bits=4,
scale_size=scale_size,
fast_dequant=fast_dequant,
with_bias=with_bias)
m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias
)
else:
kernel = matmul(
m,
......@@ -518,7 +517,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
threads=256,
split=1,
fast_dequant=fast_dequant,
with_bias=with_bias)
with_bias=with_bias,
)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto)
......
......@@ -7,29 +7,28 @@ import torch
from dequantize_utils import torch_convert_bit_twiddling, torch_convert
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr,
dtype: str):
def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr, dtype: str):
"""
Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale.
Convert a 4-bit field packed in a uint8 into a bfloat16 value, applying an exponent scale.
This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its
bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by
`scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation.
This helper extracts a 4-bit nibble from `val` at byte-nibble position `pos`, interprets its
bits as a sign/exponent/mantissa in the 4-bit custom FP4 layout, adjusts the exponent by
`scale` (clamped to an 8-bit range), and assembles the corresponding bfloat16 representation.
Parameters:
nbit (int): Number of bits in the packed field (must be 4).
val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields.
pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field).
scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like).
dtype (str): Destination dtype string (must be "bfloat16").
Parameters:
nbit (int): Number of bits in the packed field (must be 4).
val (tir.PrimExpr): Packed input value of dtype `uint8` containing one or more 4-bit fields.
pos (tir.PrimExpr): Index of the nibble within `val` (used to shift/extract the 4-bit field).
scale (tir.PrimExpr): Per-element exponent adjustment added to the extracted exponent (uint-like).
dtype (str): Destination dtype string (must be "bfloat16").
Returns:
tir.PrimExpr: The resulting value reinterpreted as `bfloat16`.
Returns:
tir.PrimExpr: The resulting value reinterpreted as `bfloat16`.
Notes:
- Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8".
- The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern.
"""
Notes:
- Preconditions are enforced via assertions: nbit == 4, dtype == "bfloat16", and val.dtype == "uint8".
- The function clamps the adjusted exponent to the 8-bit range before assembling the bfloat16 bit pattern.
"""
assert nbit == 4
assert dtype == "bfloat16"
assert val.dtype == "uint8"
......@@ -43,9 +42,10 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
# To handle overflow, we may use the min function to clamp the exponent to 8 bits
# e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
m_f4 = f4 & tir.const(1, "uint16")
val_bf16 = tir.reinterpret("bfloat16",
((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16"))
| (m_f4 << tir.const(6, "uint16"))).astype("uint16"))
val_bf16 = tir.reinterpret(
"bfloat16",
((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16")) | (m_f4 << tir.const(6, "uint16"))).astype("uint16"),
)
return val_bf16
......@@ -65,6 +65,7 @@ def get_configs():
List[dict]: A list of configuration dictionaries covering all combinations.
"""
import itertools
iter_params = dict(
block_M=[64, 128, 256],
block_N=[64, 128, 256],
......@@ -73,67 +74,71 @@ def get_configs():
threads=[128, 256, 512],
split=[1, 2],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
@tilelang.autotune(configs=get_configs(),)
@tilelang.jit(out_idx=[-1],)
def matmul(M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
source_format='uint',
num_bits=4,
scale_size=32,
fast_dequant=True,
with_bias=False,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1):
return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
@tilelang.autotune(
configs=get_configs(),
)
@tilelang.jit(
out_idx=[-1],
)
def matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
source_format="uint",
num_bits=4,
scale_size=32,
fast_dequant=True,
with_bias=False,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1,
):
"""
Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype.
The generated kernel accepts:
- A: dense matrix with element type `in_dtype`.
- B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)).
- Scale: per-block scale/exponent information used to dequantize B.
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (False): uses a simple elementwise dequantization helper.
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
out_dtype (str): output tensor element type (e.g., "bfloat16").
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
scale_size (int, optional): number of elements grouped per scale entry (default 32).
fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True).
block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128).
num_stages (int, optional): pipelining stages for K loop (default 2).
threads (int, optional): threads per block used by the kernel (default 256).
split (int, optional): split factor along K used by the scheduler (default 1).
with_bias (bool, optional): whether to add Bias to the output (default False).
Returns:
A T.prim_func implementing the tiled, pipelined GEMM that:
- loads tiled blocks of A and packed B to shared memory,
- dequantizes B via the chosen path into a shared dequantized tile,
- performs a tiled GEMM accumulating into local fragments,
- writes the final MxN block to the global output tensor.
Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype.
Notes:
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile.
- An assertion enforces that K % (block_K * split) == 0.
The generated kernel accepts:
- A: dense matrix with element type `in_dtype`.
- B: packed quantized matrix stored as uint8 with `num_bits` bits per element (QK = K / (8/num_bits)).
- Scale: per-block scale/exponent information used to dequantize B.
The kernel dequantizes B to a working floating format (out_dtype/accum_dtype) using one of two paths:
- fast_dequant (True): uses an external, hardware/implementation-specific intrinsic group (twiddling) for batch dequantization.
- fast_dequant (False): uses a simple elementwise dequantization helper.
Parameters:
M, N, K (int): matrix dimensions (A is MxK, result is MxN). K must be divisible by (block_K * split).
in_dtype (str): element type of A (e.g., "fp4" in this file).
out_dtype (str): output tensor element type (e.g., "bfloat16").
accum_dtype (str): accumulation type used for the inner GEMM.
source_format (str, optional): format string passed to intrinsic selector (default "uint").
num_bits (int, optional): number of bits per quantized element in B (default 4).
scale_size (int, optional): number of elements grouped per scale entry (default 32).
fast_dequant (bool, optional): choose the fast intrinsic dequantization path when available (default True).
block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128).
num_stages (int, optional): pipelining stages for K loop (default 2).
threads (int, optional): threads per block used by the kernel (default 256).
split (int, optional): split factor along K used by the scheduler (default 1).
with_bias (bool, optional): whether to add Bias to the output (default False).
Returns:
A T.prim_func implementing the tiled, pipelined GEMM that:
- loads tiled blocks of A and packed B to shared memory,
- dequantizes B via the chosen path into a shared dequantized tile,
- performs a tiled GEMM accumulating into local fragments,
- writes the final MxN block to the global output tensor.
Notes:
- The function queries an intrinsic group to obtain a fast dequantization implementation when fast_dequant is enabled; that intrinsic must supply a valid C source and function name.
- The kernel layout uses swizzled shared-memory layouts for A, B, and the shared C tile.
- An assertion enforces that K % (block_K * split) == 0.
"""
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"
......@@ -150,6 +155,7 @@ def matmul(M,
assert K % (block_K * split) == 0
from tilelang.quantize import get_mxfp_intrin_group
# fast_dequant_bf16_fp4_twiddling
mxfp_intrin_info = get_mxfp_intrin_group(
out_dtype=in_dtype,
......@@ -252,8 +258,7 @@ def matmul(M,
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
B_dequantize_shared[index // block_K,
index % block_K] = B_dequantize_local_thread[v]
B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v]
return fast_dequant_bf16_fp4_twiddling
......@@ -301,8 +306,8 @@ def matmul(M,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
Scale_shared[
i, k * block_K // scale_size + j //
scale_size], # Scale is the exponential part, within the representation of uint8
i, k * block_K // scale_size + j // scale_size
], # Scale is the exponent, stored as a uint8
dtype=out_dtype,
) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size]))
T.copy(B_dequantize_local, B_dequantize_shared)
......@@ -311,22 +316,22 @@ def matmul(M,
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Scale: T.Tensor(Scale_shape, storage_dtype),
Bias: T.Tensor(Bias_shape, out_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Scale: T.Tensor(Scale_shape, storage_dtype),
Bias: T.Tensor(Bias_shape, out_dtype),
C: T.Tensor((M, N), out_dtype),
):
"""
Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C.
Tiled, pipelined kernel entry that multiplies A with a quantized B (with per-block Scale) producing C.
This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function.
This prim-level kernel implements a blocked, multi-threaded matmul: it loads tiles of A and the packed/quantized B into shared memory, dequantizes B (either via the fast intrinsic twiddling path or the simple per-element path), performs a block GEMM (with B transposed), and writes the accumulated block results into the output tensor C. The kernel allocates shared buffers for A, B, and the dequantized B, and a local fragment for accumulation; it runs over K in pipelined stages and expects the provided shapes and dtypes to match the tiling parameters used to build the function.
Parameters are self-descriptive in the signature; notable behaviors:
- B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM.
- The selected dequantization path is controlled by the outer-scope flag `fast_dequant`.
- The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization).
- The function writes results in-place into C.
Parameters are self-descriptive in the signature; notable behaviors:
- B is stored in a compact uint8-packed layout (num_bits per element) and is dequantized using Scale before GEMM.
- The selected dequantization path is controlled by the outer-scope flag `fast_dequant`.
- The GEMM uses transpose_B=True (i.e., multiplies A · B^T after dequantization).
- The function writes results in-place into C.
"""
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -339,16 +344,20 @@ def matmul(M,
# May use much more shared memory than necessary
Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype)
T.annotate_layout({
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
})
T.annotate_layout(
{
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
}
)
if with_bias:
T.annotate_layout({
Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared),
})
T.annotate_layout(
{
Bias_shared: tilelang.layout.make_swizzled_layout(Bias_shared),
}
)
if threads == 512:
T.disable_warp_group_reg_alloc()
......@@ -357,26 +366,24 @@ def matmul(M,
# T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N],
# Bias_shared)
# T.copy(Bias_shared, C_local)
T.copy(Bias[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N],
C_local)
T.copy(Bias[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N], C_local)
else:
T.clear(C_local)
# Use 1D TMA to load Scale
T.copy(Scale[bx * block_N:(bx + 1) * block_N, :], Scale_shared)
T.copy(Scale[bx * block_N : (bx + 1) * block_N, :], Scale_shared)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
if fast_dequant:
get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared,
k)
get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, k)
else:
get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k)
T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N])
T.copy(C_shared, C[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N])
return main
......@@ -399,7 +406,7 @@ def ref_program_twiddling(A, qB, Scale, Bias=None):
B = torch_convert_bit_twiddling(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32]))
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C
......@@ -424,7 +431,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias):
B = torch_convert_bit_twiddling(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32]))
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
C = C.to(torch.__getattribute__(dtypeC))
return C
......@@ -450,7 +457,7 @@ def ref_program_simple(A, qB, Scale, Bias=None):
B = torch_convert(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32]))
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C
......@@ -480,7 +487,7 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias):
B = torch_convert(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
B[i][j] = B[i][j] * (2 ** (Scale[i][j // 32]))
C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
C = C.to(torch.__getattribute__(dtypeC))
return C
......@@ -507,16 +514,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
if tune:
kernel = matmul(
m,
n,
k,
"bfloat16",
"bfloat16",
"float32",
num_bits=4,
scale_size=scale_size,
fast_dequant=fast_dequant,
with_bias=with_bias)
m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, scale_size=scale_size, fast_dequant=fast_dequant, with_bias=with_bias
)
else:
kernel = matmul(
m,
......@@ -534,7 +533,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
threads=256,
split=1,
fast_dequant=fast_dequant,
with_bias=with_bias)
with_bias=with_bias,
)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto)
......
......@@ -24,6 +24,7 @@ def matmul(
num_bits=4,
):
from tilelang.quantize import _tir_packed_to_unsigned_convert
num_elems_per_byte = 8 // num_bits
storage_dtype = "int8"
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
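# e.g. for storage_dtype == "int8" this keeps only the digit characters, so storage_nbit == 8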
......@@ -39,9 +40,9 @@ def matmul(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
......@@ -58,21 +59,19 @@ def matmul(
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
for i in T.serial(block_N * block_K // num_elems_per_byte //
(threads * local_size_compressed)):
for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * local_size_compressed)):
for v in T.vectorized(0, local_size_compressed):
index = i * threads * local_size_compressed + tx * local_size_compressed + v
vi = index // (block_K // num_elems_per_byte)
vj = index % (block_K // num_elems_per_byte)
B_local[v] = B_shared[vi, vj]
for v in T.serial(0, local_size):
B_dequantize_local[v] = _tir_packed_to_unsigned_convert(
storage_type, storage_nbit)(
num_bits,
B_local[v // num_elems_per_byte],
v % num_elems_per_byte,
dtype=in_dtype,
)
B_dequantize_local[v] = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
num_bits,
B_local[v // num_elems_per_byte],
v % num_elems_per_byte,
dtype=in_dtype,
)
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
vi = index // block_K
......@@ -121,9 +120,7 @@ def run_gemm(
def ref_program(A, qB):
import torch
B = (
torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
dtype=torch.half).to(torch.half).to(A.device))
B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
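# Worked example of the nibble unpacking above (illustrative byte): if qB[i][j // 2] == 0xB7,
# an even j reads the low nibble, (0xB7 >> 0) & 0xF == 0x7, and an odd j reads the high
# nibble, (0xB7 >> 4) & 0xF == 0xB.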
......@@ -146,9 +143,11 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
):
from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout as make_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitterWithLadderTransform,)
TensorCoreIntrinEmitterWithLadderTransform,
)
from bitblas.gpu.intrin.lop3 import decode_i4_to_f16
assert in_dtype in [
"float16",
"int8",
......@@ -192,8 +191,7 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
pad_factor = 8
A_shape = (M, K)
B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y,
micro_size_k // num_elems_per_byte)
B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k // num_elems_per_byte)
A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K)
B_shared_shape = (
block_N // micro_size_y,
......@@ -228,7 +226,8 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
chunk=chunk,
reduce_k=reduce_k,
transform_kind_b=transform_b,
num_elems_per_byte=num_elems_per_byte)
num_elems_per_byte=num_elems_per_byte,
)
vec_load_qb = 16
if block_N * (block_K // reduce_k) // num_elems_per_byte // threads < vec_load_qb:
......@@ -236,14 +235,11 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads,
prelude=decode_i4_to_f16) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads, prelude=decode_i4_to_f16) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
......@@ -255,40 +251,36 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
thread_binding = T.get_thread_binding(0)
rk = T.get_thread_binding(1)
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
})
T.annotate_layout(
{
A_shared: make_swizzle_layout(A_shared),
}
)
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, (block_K // reduce_k)):
vk = rk * (block_K // reduce_k) + k
A_shared[i, vk] = A[by * block_M + i, ko * block_K + vk]
# TODO(lei): the Layout Inference Pass does not handle this four-dim int8 load efficiently
for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte //
(threads * vec_load_qb)):
for i in T.serial(block_N * (block_K // reduce_k) // num_elems_per_byte // (threads * vec_load_qb)):
for v in T.vectorized(0, vec_load_qb):
t = thread_binding
idx = i * threads * vec_load_qb * reduce_k + rk * threads * vec_load_qb + t * vec_load_qb + v
vkk = idx % (micro_size_k // num_elems_per_byte)
vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y
vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % (
block_K // micro_size_k)
vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y //
(block_K // micro_size_k)) % (
block_N // micro_size_y)
B_shared[vj, vk, vjj,
vkk] = B[bx * (block_N // micro_size_y) + vj,
ko * (block_K // micro_size_k) + vk, vjj, vkk]
vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % (block_K // micro_size_k)
vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y // (block_K // micro_size_k)) % (
block_N // micro_size_y
)
B_shared[vj, vk, vjj, vkk] = B[bx * (block_N // micro_size_y) + vj, ko * (block_K // micro_size_k) + vk, vjj, vkk]
for ki in T.serial(0, (block_K // (micro_size_k * reduce_k))):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
......@@ -307,9 +299,13 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
for j in T.serial(warp_cols):
local_size_b = mma_emitter.local_size_b
T.call_extern('handle', 'decode_i4u_to_f16',
T.address_of(B_local[j * local_size_b // num_elems_per_byte]),
T.address_of(B_dequantize_local[j * local_size_b]), 8)
T.call_extern(
"handle",
"decode_i4u_to_f16",
T.address_of(B_local[j * local_size_b // num_elems_per_byte]),
T.address_of(B_dequantize_local[j * local_size_b]),
8,
)
mma_emitter.mma(A_local, B_dequantize_local, C_local)
......@@ -328,7 +324,8 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
reduced_accum_res[0],
rk,
dtype="handle",
))
)
)
if rk == 0:
C_local[n] = reduced_accum_res[0]
......@@ -340,9 +337,9 @@ def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
for i, j in T.Parallel(block_M, (block_N // reduce_k)):
vj = rk * (block_N // reduce_k) + j
C[by * block_M + i,
bx * block_N + vj] = C_shared[i // micro_size_x, vj // micro_size_y,
i % micro_size_x, vj % micro_size_y]
C[by * block_M + i, bx * block_N + vj] = C_shared[
i // micro_size_x, vj // micro_size_y, i % micro_size_x, vj % micro_size_y
]
return main
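# A small sketch of the 2D -> 4D index mapping used in the write-back above, with illustrative
# micro-tile sizes: (i, vj) maps to (i // mx, vj // my, i % mx, vj % my), which covers the
# micro-tiled C_shared layout exactly once per output element.
_mx, _my, _bm, _bn = 16, 8, 32, 16
_seen = {(i // _mx, vj // _my, i % _mx, vj % _my) for i in range(_bm) for vj in range(_bn)}
assert len(_seen) == _bm * _bn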
......@@ -357,8 +354,8 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct
transform_b,
):
import bitblas
matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
M, N, K, in_dtype, out_dtype, accum_dtype, transform_b)
matmul = tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(M, N, K, in_dtype, out_dtype, accum_dtype, transform_b)
kernel = tilelang.compile(matmul, out_idx=[2])
src_code = kernel.get_kernel_source()
......@@ -371,8 +368,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct
storage_dtype = "int8"
A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype))
qB = torch.randint(
0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype))
qB = torch.randint(0, 127, (N, K // num_elems_per_byte), device="cuda", dtype=getattr(torch, storage_dtype))
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, accum_dtype))
ladder_permutate_config = bitblas.ops.LadderPermutateConfig(
......@@ -407,9 +403,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correct
# Ensure that the latency is not None
assert latency is not None
B = (
torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
dtype=torch.half).to(torch.half).to(A.device))
B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = ((qB[i][j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
......@@ -429,8 +423,7 @@ def test_run_dequantize_gemm():
@tilelang.testing.requires_package("bitblas")
def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4():
assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(
256, 1024, 512, "float16", "float16", "float16", 3)
assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(256, 1024, 512, "float16", "float16", "float16", 3)
def main():
......
......@@ -21,18 +21,17 @@ def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype:
e_f16 = e_f4 + tir.const(14, "uint16")
m_f4 = f4 & tir.const(1, "uint16")
m_f16 = m_f4
val_f16 = tir.reinterpret("float16",
((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")
| m_f16 << tir.const(9, "uint16")).astype("uint16"))
val_f16 = tir.reinterpret(
"float16", ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16") | m_f16 << tir.const(9, "uint16")).astype("uint16")
)
# return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)
return val_f16
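# Worked example of the assembly above (illustrative inputs): for s=0, e_f4=1, m_f4=0 the code
# forms e_f16 = 1 + 14 = 15 and the uint16 pattern (15 << 10) == 0x3C00, which reinterprets as
# float16 1.0; setting m_f4=1 additionally sets the mantissa MSB (1 << 9), giving 0x3E00 == 1.5.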
def torch_convert(tensor):
def print_bit(name, val):
val_cpu = val.cpu().item()
binary_repr = f'{val_cpu:032b}'
binary_repr = f"{val_cpu:032b}"
print(name, binary_repr)
def _convert(val, pos):
......@@ -68,8 +67,8 @@ def test_convert(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
@T.prim_func
def main(
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((N, K), in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((N, K), in_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
......@@ -118,19 +117,11 @@ def get_configs():
splits = [1]
_configs = list(itertools.product(block_M, block_N, block_K, num_stages, threads, splits))
configs = [{
'block_M': c[0],
'block_N': c[1],
'block_K': c[2],
'num_stages': c[3],
'threads': c[4],
'split': c[5]
} for c in _configs]
configs = [{"block_M": c[0], "block_N": c[1], "block_K": c[2], "num_stages": c[3], "threads": c[4], "split": c[5]} for c in _configs]
return configs
def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
@tilelang.jit(out_idx=[2])
def kernel_func(block_M, block_N, block_K, num_stages, threads, split=1):
num_elems_per_byte = 8 // num_bits
......@@ -145,17 +136,12 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
@T.prim_func
def main_split(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
):
SplitC = T.alloc_buffer([
split, (N + block_N - 1) // block_N * block_N,
(M + block_M - 1) // block_M * block_M
], out_dtype)
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), split,
threads=threads) as (bx, by, bz):
SplitC = T.alloc_buffer([split, (N + block_N - 1) // block_N * block_N, (M + block_M - 1) // block_M * block_M], out_dtype)
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), split, threads=threads) as (bx, by, bz):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
......@@ -164,10 +150,12 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
})
T.annotate_layout(
{
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
}
)
T.clear(Ct_local)
for k in T.Pipelined(K // (block_K * split), num_stages=num_stages):
......@@ -183,8 +171,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
)
T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, SplitC[bz, bx * block_N:(bx + 1) * block_N,
by * block_M:(by + 1) * block_M])
T.copy(Ct_local, SplitC[bz, bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M])
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by):
acc = T.alloc_fragment((block_N, block_M), out_dtype)
T.clear(acc)
......@@ -195,12 +182,11 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
......@@ -209,10 +195,12 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
})
T.annotate_layout(
{
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
}
)
T.clear(Ct_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
......@@ -229,8 +217,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, Ct_shared)
T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N,
by * block_M:(by + 1) * block_M])
T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M])
if split == 1:
return main
......@@ -241,12 +228,7 @@ def matmul(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[2])
def kernel(block_M=None,
block_N=None,
block_K=None,
num_stages=None,
threads=None,
split=None):
def kernel(block_M=None, block_N=None, block_K=None, num_stages=None, threads=None, split=None):
return kernel_func(block_M, block_N, block_K, num_stages, threads, split).prim_func
return kernel()
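# A minimal NumPy sketch of the split-K scheme used by main_split above: each split computes a
# partial GEMM over its own K chunk (its SplitC slice), and summing the slices reproduces the
# full product. Sizes here are illustrative assumptions.
import numpy as np

_split, _K = 2, 64
_A = np.random.rand(32, _K).astype(np.float32)
_B = np.random.rand(48, _K).astype(np.float32)
_chunk = _K // _split
_parts = [_B[:, s * _chunk:(s + 1) * _chunk] @ _A[:, s * _chunk:(s + 1) * _chunk].T for s in range(_split)]
assert np.allclose(sum(_parts), _B @ _A.T, atol=1e-4)  # Ct == B_dequant @ A^T, as in the kernel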
......@@ -269,10 +251,10 @@ def ref_program(A, qB):
def main(m=256, n=256, k=256, tune=False):
total_flops = 2 * m * n * k
if (not tune):
kernel = matmul(
m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)(
block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1)
if not tune:
kernel = matmul(m, n, k, "float16", "float16", "float32", num_bits=4, tune=tune)(
block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1
)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
......@@ -293,10 +275,10 @@ def main(m=256, n=256, k=256, tune=False):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--m', type=int, default=256, help='M')
parser.add_argument('--n', type=int, default=256, help='N')
parser.add_argument('--k', type=int, default=256, help='K')
parser.add_argument('--tune', action='store_true', help='tune configs')
parser.add_argument("--m", type=int, default=256, help="M")
parser.add_argument("--n", type=int, default=256, help="N")
parser.add_argument("--k", type=int, default=256, help="K")
parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args()
M, N, K = args.m, args.n, args.k
main(M, N, K, args.tune)
......@@ -42,8 +42,8 @@ def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
@T.prim_func
def main(
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((N, K), in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((N, K), in_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
......@@ -66,13 +66,12 @@ def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
def torch_convert(tensor):
def _convert(val, pos):
assert val.dtype == torch.uint8
val = val.view(torch.int8)
mask = (1 << 4) - 1
i4_shifted = ((val >> (pos * 4)) & mask)
i4 = ((i4_shifted << 4) >> 4)
i4_shifted = (val >> (pos * 4)) & mask
i4 = (i4_shifted << 4) >> 4
return i4.view(torch.int8)
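# A short sketch of the sign-extension trick above, relying on the same int8 view: shifting the
# nibble to the top of the int8 and arithmetic-shifting it back fills the high bits with the
# sign bit, so 0xF decodes to -1 while 0x7 stays 7. Values here are illustrative.
import torch

_nibbles = torch.tensor([0x7, 0xF], dtype=torch.int8)
assert ((_nibbles << 4) >> 4).tolist() == [7, -1]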
......@@ -94,7 +93,6 @@ def ref_program(A, qB):
def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune=False):
@tilelang.jit(out_idx=[2])
def kernel_func(block_M, block_N, block_K, num_stages, threads):
num_elems_per_byte = 8 // num_bits
......@@ -109,12 +107,11 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
Ct: T.Tensor((N, M), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
......@@ -123,10 +120,12 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune
Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)
T.annotate_layout({
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
})
T.annotate_layout(
{
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
}
)
T.clear(Ct_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
......@@ -143,8 +142,7 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune
T.copy(B_dequantize_local, B_dequantize_prev_local)
T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
T.copy(Ct_local, Ct_shared)
T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N,
by * block_M:(by + 1) * block_M])
T.copy(Ct_shared, Ct[bx * block_N : (bx + 1) * block_N, by * block_M : (by + 1) * block_M])
return main
......@@ -167,10 +165,10 @@ def matmul_int8xint4(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits=4, tune
def main(m=128, n=256, k=256, tune=False):
total_flops = 2 * m * n * k
if (not tune):
kernel = matmul_int8xint4(
m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)(
block_M=32, block_N=32, block_K=128, num_stages=1, threads=128)
if not tune:
kernel = matmul_int8xint4(m, n, k, "int8", "int32", "int32", num_bits=4, tune=tune)(
block_M=32, block_N=32, block_K=128, num_stages=1, threads=128
)
profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program, rtol=1e-2, atol=1e-2)
print("All checks pass.")
......
......@@ -4,7 +4,8 @@ from typing import Optional, Callable, Any
import torch
from tilelang import DataType
from tilelang.quantize import (
_tir_packed_int_to_int_convert,)
_tir_packed_int_to_int_convert,
)
@tilelang.jit
......@@ -26,11 +27,10 @@ def dequantize_gemv(
group_size: int = -1,
with_scaling: bool = False,
) -> Callable[..., Any]:
assert n_partition is not None, "n_partition must be provided"
assert reduce_thread is not None, (
"reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMV"
"sch_outer_reduction_with_config is not implemented")
"reduce_thread must be provided currently, as related bitblas.gpu.gemv.GEMVsch_outer_reduction_with_config is not implemented"
)
assert trans_A is False, "Dequantize is only implemented for trans_A=False currently"
assert trans_B is True, "Dequantize is only implemented for trans_B=True currently"
......@@ -81,12 +81,12 @@ def dequantize_gemv(
C: T.Tensor[C_shape, out_dtype],
):
with T.Kernel(
T.ceildiv(N, n_partition),
M,
threads=(reduce_thread, n_partition),
T.ceildiv(N, n_partition),
M,
threads=(reduce_thread, n_partition),
) as (
bx,
by,
bx,
by,
):
A_local = T.alloc_local((micro_size_k,), in_dtype)
B_quant_local = T.alloc_local([micro_size_k_compressed], storage_dtype)
......@@ -107,8 +107,7 @@ def dequantize_gemv(
for v in T.vectorized(micro_size_k_compressed):
B_quant_local[v] = B[
bx * n_partition + ni,
ko * (reduce_thread * micro_size_k_compressed) +
kr * micro_size_k_compressed + v,
ko * (reduce_thread * micro_size_k_compressed) + kr * micro_size_k_compressed + v,
]
if fast_decoding:
......@@ -120,10 +119,9 @@ def dequantize_gemv(
)
else:
for ki in T.serial(micro_size_k):
B_dequantize_local[ki] = _tir_packed_int_to_int_convert(
storage_type,
storage_nbit)(num_bits, B_quant_local[ki // num_elems_per_byte],
ki % num_elems_per_byte, in_dtype)
B_dequantize_local[ki] = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
num_bits, B_quant_local[ki // num_elems_per_byte], ki % num_elems_per_byte, in_dtype
)
if use_dp4a:
for ki in T.serial(micro_size_k // dp4a_size):
......@@ -137,9 +135,9 @@ def dequantize_gemv(
accum_res[0] += A_local[ki] * B_dequantize_local[ki]
with T.attr(
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
"reduce_scope",
T.reinterpret(T.uint64(0), dtype="handle"),
):
T.evaluate(
T.tvm_thread_allreduce(
......@@ -149,7 +147,8 @@ def dequantize_gemv(
reduced_accum_res[0],
kr,
dtype="handle",
))
)
)
if kr == 0:
C[by, bx * n_partition + ni] = reduced_accum_res[0]
......@@ -174,26 +173,39 @@ def main() -> None:
group_size = -1
with_scaling = False
kernel = dequantize_gemv(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits, storage_dtype,
source_format, n_partition, reduce_thread, fast_decoding, trans_A,
trans_B, group_size, with_scaling)
kernel = dequantize_gemv(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
num_bits,
storage_dtype,
source_format,
n_partition,
reduce_thread,
fast_decoding,
trans_A,
trans_B,
group_size,
with_scaling,
)
storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))
num_elems_per_byte = storage_nbit // num_bits
A = torch.rand(M, K, dtype=getattr(torch, in_dtype)).cuda()
qB = torch.randint(
0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda()
qB = torch.randint(0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda()
C = torch.zeros(M, N, dtype=getattr(torch, accum_dtype)).cuda()
if fast_decoding:
from tilelang.quantize.utils import interleave_weight
qB = interleave_weight(qB, num_bits, in_dtype)
kernel(A, qB, C)
# int4 reference
B = (
torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4,
dtype=torch.half).to(torch.half).to(A.device))
B = torch.zeros(qB.shape[0], qB.shape[1] * 8 // 4, dtype=torch.half).to(torch.half).to(A.device)
for j in range(B.shape[1]):
B[:, j] = ((qB[:, j // 2] >> (4 * (j % 2))) & 0xF).to(torch.half)
......
......@@ -25,6 +25,7 @@ def get_configs():
List[dict]: A list of configuration dictionaries covering all combinations.
"""
import itertools
iter_params = dict(
block_M=[128],
block_N=[64, 128, 256],
......@@ -33,33 +34,33 @@ def get_configs():
threads=[128, 256, 512],
split=[1],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
return [{k: v for k, v in zip(iter_params, values)} for values in itertools.product(*iter_params.values())]
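# A small self-contained illustration (toy parameter lists, not the ones above) of how the
# zip/product expression expands parameter lists into configuration dictionaries:
import itertools

toy_params = dict(block_N=[64, 128], threads=[128, 256])
toy_configs = [dict(zip(toy_params, values)) for values in itertools.product(*toy_params.values())]
# toy_configs -> [{'block_N': 64, 'threads': 128}, {'block_N': 64, 'threads': 256},
#                 {'block_N': 128, 'threads': 128}, {'block_N': 128, 'threads': 256}]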
@tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[-1])
def matmul(M,
N,
K,
topk,
E,
padding_M,
in_dtype,
out_dtype,
accum_dtype,
source_format='uint',
num_bits=4,
scale_size=32,
fast_dequant=True,
with_bias=False,
block_M=128,
block_N=256,
block_K=128,
num_stages=2,
threads=256,
split=1):
def matmul(
M,
N,
K,
topk,
E,
padding_M,
in_dtype,
out_dtype,
accum_dtype,
source_format="uint",
num_bits=4,
scale_size=32,
fast_dequant=True,
with_bias=False,
block_M=128,
block_N=256,
block_K=128,
num_stages=2,
threads=256,
split=1,
):
"""
Construct and return a grouped (Mixture-of-Experts) matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized, expert-grouped B (shape ExNxQK) and writes an output of shape (M, topk, N) in out_dtype.
......@@ -115,11 +116,12 @@ def matmul(M,
Block_QK = block_K // num_elems_per_byte
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, Block_QK)
Bias_shared_shape = (block_N)
Bias_shared_shape = block_N
B_dequantize_shared_shape = (block_N, block_K)
assert K % (block_K * split) == 0
from tilelang.quantize import get_mxfp_intrin_group
# fast_dequant_bf16_fp4_twiddling
mxfp_intrin_info = get_mxfp_intrin_group(
out_dtype=in_dtype,
......@@ -221,19 +223,16 @@ def matmul(M,
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
B_dequantize_shared[index // block_K,
index % block_K] = B_dequantize_local_thread[v]
B_dequantize_shared[index // block_K, index % block_K] = B_dequantize_local_thread[v]
return fast_dequant_bf16_fp4_twiddling
def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]
@T.macro
def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k):
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype)
......@@ -244,8 +243,8 @@ def matmul(M,
B_local[i, j // num_elems_per_byte],
j % num_elems_per_byte,
Scale_shared[
i, k * block_K // scale_size + j //
scale_size], # Scale is the exponential part, within the representation of uint8
i, k * block_K // scale_size + j // scale_size
], # Scale is the exponential part, within the representation of uint8
dtype=out_dtype,
) * T.shift_left(1, (Scale_shared[i, k * block_K // scale_size + j // scale_size]))
T.copy(B_dequantize_local, B_dequantize_shared)
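# Worked example with assumed values (not taken from the kernel): if a decoded fp4 element
# is 1.5 and its scale byte is 3, the dequantized value is 1.5 * 2**3; the uint8 scale acts
# directly as a power-of-two exponent, matching the T.shift_left(1, Scale_shared[...])
# factor above and the 2**Scale factor in ref_moe below.
decoded_fp4 = 1.5   # hypothetical decoded e2m1 value
scale_byte = 3      # hypothetical shared exponent byte
dequantized = decoded_fp4 * (1 << scale_byte)  # -> 12.0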
......@@ -254,19 +253,17 @@ def matmul(M,
@T.prim_func
def main(
A: T.Tensor((M, K), in_dtype),
B: T.Tensor((E, N, QK), storage_dtype),
Scale: T.Tensor((E, N, K // scale_size), storage_dtype),
Bias: T.Tensor((E, N), out_dtype),
# Add fusedmoe tensors
topk_weights: T.Tensor((M * topk), out_dtype),
sorted_token_ids: T.Tensor((padding_M), "int32"),
expert_ids: T.Tensor((padding_M // block_M), "int32"),
C: T.Tensor((M, topk, N), out_dtype),
A: T.Tensor((M, K), in_dtype),
B: T.Tensor((E, N, QK), storage_dtype),
Scale: T.Tensor((E, N, K // scale_size), storage_dtype),
Bias: T.Tensor((E, N), out_dtype),
# Add fusedmoe tensors
topk_weights: T.Tensor((M * topk), out_dtype),
sorted_token_ids: T.Tensor((padding_M), "int32"),
expert_ids: T.Tensor((padding_M // block_M), "int32"),
C: T.Tensor((M, topk, N), out_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(padding_M, block_M), threads=threads) as (bx, by):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(padding_M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
......@@ -280,17 +277,19 @@ def matmul(M,
# May use much more shared memory than necessary
Scale_shared = T.alloc_shared((block_N, K // scale_size), storage_dtype)
T.annotate_layout({
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
})
T.annotate_layout(
{
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
B_shared: tilelang.layout.make_swizzled_layout(B_shared),
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
}
)
T.use_swizzle(10)
if threads == 512:
T.disable_warp_group_reg_alloc()
T.copy(sorted_token_ids[by * block_M:(by + 1) * block_M], sorted_token_ids_shared)
T.copy(sorted_token_ids[by * block_M : (by + 1) * block_M], sorted_token_ids_shared)
expert_id[0] = expert_ids[by]
# Get the topk weights of each token in the current block
......@@ -300,11 +299,11 @@ def matmul(M,
# Get bias and scale based on the expert id
if with_bias:
T.copy(Bias[expert_id[0], bx * block_N:(bx + 1) * block_N], Bias_shared)
T.copy(Bias[expert_id[0], bx * block_N : (bx + 1) * block_N], Bias_shared)
else:
T.clear(Bias_shared)
T.copy(Scale[expert_id[0], bx * block_N:(bx + 1) * block_N, :], Scale_shared)
T.copy(Scale[expert_id[0], bx * block_N : (bx + 1) * block_N, :], Scale_shared)
for i, j in T.Parallel(block_M, block_N):
C_local[i, j] = Bias_shared[j]
......@@ -317,14 +316,13 @@ def matmul(M,
base = copy_i * threads * 16 + tx * 16
if sorted_token_ids_shared[base // block_K] != -1:
for copy_j in T.vectorized(16):
A_shared[base // block_K, base % block_K +
copy_j] = A[sorted_token_ids_shared[base // block_K] // topk,
k * block_K + base % block_K + copy_j]
A_shared[base // block_K, base % block_K + copy_j] = A[
sorted_token_ids_shared[base // block_K] // topk, k * block_K + base % block_K + copy_j
]
T.copy(B[expert_id[0], bx * block_N, k * block_K // num_elems_per_byte], B_shared)
if fast_dequant:
get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared,
k)
get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, k)
else:
get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k)
......@@ -338,10 +336,11 @@ def matmul(M,
base = copy_i * threads * 16 + tx * 16
if sorted_token_ids_shared[base // block_N] != -1:
for copy_j in T.vectorized(16):
C[sorted_token_ids_shared[base // block_N] // topk,
sorted_token_ids_shared[base // block_N] % topk, bx * block_N +
base % block_N + copy_j] = C_shared[base // block_N,
base % block_N + copy_j]
C[
sorted_token_ids_shared[base // block_N] // topk,
sorted_token_ids_shared[base // block_N] % topk,
bx * block_N + base % block_N + copy_j,
] = C_shared[base // block_N, base % block_N + copy_j]
return main
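# Illustration with assumed numbers of the index decoding used in the epilogue above: each
# entry of sorted_token_ids is a flattened (token, expert-choice) slot, so with topk = 4 the
# slot id 11 addresses token 11 // 4 == 2 at choice index 11 % 4 == 3, and the tile is
# written to C[2, 3, :].
topk = 4
slot_id = 11
token, choice = slot_id // topk, slot_id % topk  # -> (2, 3)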
......@@ -355,7 +354,7 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc
assert scale_size == 32 # MXFP4
# Initialize output tensor
C = torch.ones((M, topk, N), dtype=getattr(torch, dtypeC), device='cuda')
C = torch.ones((M, topk, N), dtype=getattr(torch, dtypeC), device="cuda")
# Iterate over sorted_token_ids
for idx in range(len(sorted_token_ids)): # padding_M
......@@ -370,14 +369,11 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc
# Dequantize the expert weights
B = torch_convert_bit_twiddling(qB[expert_id]) # shape: (N, K)
B *= 2**(
Scale[expert_id][:, (torch.arange(B.shape[1], device=B.device) // scale_size)].to(
torch.bfloat16))
B *= 2 ** (Scale[expert_id][:, (torch.arange(B.shape[1], device=B.device) // scale_size)].to(torch.bfloat16))
# Compute the output for this token-expert pair
# token_embedding @ B.T + bias
output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to(
torch.bfloat16)) + Bias[expert_id]
output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to(torch.bfloat16)) + Bias[expert_id]
output = output.to(torch.__getattribute__(dtypeC))
# Apply the topk weight
......@@ -391,14 +387,12 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc
def get_data(m, n, k, qk, scale_size, topk, E, block_M):
A = torch.empty(m, k, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1)
qB = torch.randint(
0, 256, (E, n, qk), dtype=torch.uint8,
device='cuda') # Quantized weight tensor for E experts.
Scale = torch.randint(0, 8, (E, n, k // scale_size), dtype=torch.uint8, device='cuda')
Bias = torch.empty(E, n, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1)
weights = torch.empty(m, E, dtype=torch.bfloat16, device='cuda').uniform_(-1, 1)
A = torch.empty(m, k, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1)
qB = torch.randint(0, 256, (E, n, qk), dtype=torch.uint8, device="cuda") # Quantized weight tensor for E experts.
Scale = torch.randint(0, 8, (E, n, k // scale_size), dtype=torch.uint8, device="cuda")
Bias = torch.empty(E, n, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1)
weights = torch.empty(m, E, dtype=torch.bfloat16, device="cuda").uniform_(-1, 1)
# topk_weights: Router weights for the top-k experts for each token.
# Shape: (m, topk)
# tokens_experts: A flattened tensor of expert assignments for each token.
......@@ -420,10 +414,7 @@ def get_data(m, n, k, qk, scale_size, topk, E, block_M):
pad_len = ((cnt + block_M - 1) // block_M) * block_M - cnt
if pad_len > 0:
# -1 for padding (`M` instead in vLLM moe_align_block_size())
group_token_ids = torch.cat([
group_token_ids,
torch.full((pad_len,), -1, dtype=group_token_ids.dtype, device='cuda')
])
group_token_ids = torch.cat([group_token_ids, torch.full((pad_len,), -1, dtype=group_token_ids.dtype, device="cuda")])
padded_token_ids.append(group_token_ids)
expert_ids.extend([eid] * ((cnt + block_M - 1) // block_M))
start = end
......@@ -431,21 +422,13 @@ def get_data(m, n, k, qk, scale_size, topk, E, block_M):
# sorted_token_ids: The final flattened and padded tensor of token indices.
sorted_token_ids = torch.cat(padded_token_ids, dim=0).to(torch.int32) # (padding_M,)
# expert_ids: The final tensor of expert IDs corresponding to `sorted_token_ids`.
expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device='cuda') # (padding_M,)
expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device="cuda") # (padding_M,)
padding_M = sorted_token_ids.shape[0] # padding_M: token number after padding
return A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M
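# A standalone sketch (hypothetical slot ids) of the block_M padding performed above,
# mirroring vLLM's moe_align_block_size: each expert's token slots are padded with -1 up to
# a multiple of block_M, and the expert contributes one expert_ids entry per block.
import torch

block_M = 4
group_token_ids = torch.tensor([0, 3, 5, 8, 9, 11], dtype=torch.int32)  # slots routed to one expert
cnt = group_token_ids.numel()
pad_len = ((cnt + block_M - 1) // block_M) * block_M - cnt
padded = torch.cat([group_token_ids, torch.full((pad_len,), -1, dtype=torch.int32)])
# padded -> tensor([ 0,  3,  5,  8,  9, 11, -1, -1], dtype=torch.int32)
# and this expert adds (cnt + block_M - 1) // block_M == 2 entries to expert_ids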
def main(m=256,
n=256,
k=256,
scale_size=32,
topk=4,
E=32,
fast_dequant=True,
with_bias=False,
tune=False):
def main(m=256, n=256, k=256, scale_size=32, topk=4, E=32, fast_dequant=True, with_bias=False, tune=False):
# Tunable parameters
block_M, block_N, block_K = 128, 256, 128 # noqa: F841
num_stages = 1 # noqa: F841
......@@ -456,8 +439,7 @@ def main(m=256,
num_bits = 4
num_elems_per_byte = 8 // num_bits
qk = k // num_elems_per_byte
A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data(
m, n, k, qk, scale_size, topk, E, block_M)
A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data(m, n, k, qk, scale_size, topk, E, block_M)
if tune:
with set_autotune_inputs([A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids]):
......@@ -510,14 +492,11 @@ def main(m=256,
expert_ids,
)
print('Tilelang kernel run finished.')
print("Tilelang kernel run finished.")
ref_output = ref_moe(
A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids,
block_M=block_M) # Maybe a little bit slow...
ref_output = ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=block_M) # Maybe a little bit slow...
latency = tilelang.profiler.do_bench(
lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=100)
latency = tilelang.profiler.do_bench(lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=100)
print("Tilelang: {:.2f} ms".format(latency))
print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
......@@ -525,32 +504,19 @@ def main(m=256,
max_val = diff.max()
max_idx = diff.argmax()
print(f"max abs diff: {max_val} at index: {max_idx}")
assert_similar(
output, ref_output, name="output",
eps=2e-5) # We care about the similarity rather than abs. difference
assert_similar(output, ref_output, name="output", eps=2e-5) # We care about the similarity rather than abs. difference
print("All checks pass. ✅")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--M", type=int, default=16384, help="M") # From gpt-oss-20b MoE's first gemm
parser.add_argument("--M", type=int, default=16384, help="M") # From gpt-oss-20b MoE's first gemm
parser.add_argument("--N", type=int, default=5760, help="N")
parser.add_argument("--K", type=int, default=2944, help="K")
parser.add_argument("--scale_size", type=int, default=32, help="scale size")
parser.add_argument(
"--topk", type=int, default=4, help="topk") # experts activated for each token
parser.add_argument("--topk", type=int, default=4, help="topk") # experts activated for each token
parser.add_argument("--E", type=int, default=32, help="E") # number of experts
parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args()
main(
args.M,
args.N,
args.K,
args.scale_size,
topk=args.topk,
E=args.E,
fast_dequant=True,
with_bias=True,
tune=args.tune)
main(args.M, args.N, args.K, args.scale_size, topk=args.topk, E=args.E, fast_dequant=True, with_bias=True, tune=args.tune)
......@@ -11,7 +11,6 @@ from utils import get_abs_err, get_err_ratio
class RegsiterLossFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, loss):
ctx.save_for_backward(loss)
......@@ -38,49 +37,43 @@ def ref_deepseek_sparse_attention_innner(
index_sm_scale: Optional[float] = None,
):
dtype = q.dtype
q, kv, index_q, index_k, weights = map(lambda x: x.to(torch.float32),
(q, kv, index_q, index_k, weights))
q, kv, index_q, index_k, weights = map(lambda x: x.to(torch.float32), (q, kv, index_q, index_k, weights))
index_sm_scale = index_q.shape[-1]**-0.5
index_sm_scale = index_q.shape[-1] ** -0.5
b, s = index_q.shape[:2]
# tl_topk_indices = tl_topk_indices.to(torch.int64)
# tl_topk_indices[tl_topk_indices == -1] = s
casual_mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device)
index_logits = einsum(index_q, index_k, 'b s1 h k, b s2 k -> b s1 h s2')
index_logits = einsum(index_q, index_k, "b s1 h k, b s2 k -> b s1 h s2")
index_logits = F.relu(index_logits)
index_logits = (index_logits * weights.unsqueeze(-1)).sum(
dim=-2, dtype=torch.float32) * index_sm_scale
index_logits = torch.where(casual_mask, index_logits, float('-inf'))
index_logits = (index_logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) * index_sm_scale
index_logits = torch.where(casual_mask, index_logits, float("-inf"))
topk_indices = torch.topk(index_logits, k=topk, dim=-1).indices
topk_logits = torch.gather(
F.pad(index_logits, (0, 1), value=float('-inf')), dim=-1, index=topk_indices)
topk_logits = torch.gather(F.pad(index_logits, (0, 1), value=float("-inf")), dim=-1, index=topk_indices)
topk_score = F.log_softmax(topk_logits, dim=-1, dtype=torch.float32)
index_topk_score = topk_score
if sm_scale is None:
sm_scale = kv.shape[-1]**-0.5
sm_scale = kv.shape[-1] ** -0.5
h = q.shape[-2]
index_mask = torch.zeros((b, s, s + 1), dtype=torch.bool, device="cuda")\
.scatter_(dim=-1, index=topk_indices, src=torch.ones_like(topk_indices, dtype=torch.bool))[:, :, :-1]
mask = repeat(casual_mask & index_mask, 'b s1 s2 -> b s1 h s2', h=h)
index_mask = torch.zeros((b, s, s + 1), dtype=torch.bool, device="cuda").scatter_(
dim=-1, index=topk_indices, src=torch.ones_like(topk_indices, dtype=torch.bool)
)[:, :, :-1]
mask = repeat(casual_mask & index_mask, "b s1 s2 -> b s1 h s2", h=h)
k, v = kv, kv[..., :dim_v]
logits = einsum(q, k, 'b s1 h d, b s2 d -> b s1 h s2') * sm_scale
logits = torch.where(mask, logits, float('-inf'))
logits = einsum(q, k, "b s1 h d, b s2 d -> b s1 h s2") * sm_scale
logits = torch.where(mask, logits, float("-inf"))
attn_score = F.softmax(logits, dim=-1, dtype=torch.float32)
o = einsum(attn_score, v, 'b s1 h s2, b s2 d -> b s1 h d')
o = einsum(attn_score, v, "b s1 h s2, b s2 d -> b s1 h d")
attn_score = attn_score.sum(dim=-2) # [b, s1, s2]
attn_topk_score = torch.gather(F.pad(attn_score, (0, 1)), dim=-1, index=topk_indices)
attn_topk_score = attn_topk_score / attn_topk_score.sum(dim=-1, keepdim=True)
loss = F.kl_div(
index_topk_score.clip(-100, 0),
attn_topk_score.detach().log().clip(-100, 0),
log_target=True,
reduction="sum")
loss = F.kl_div(index_topk_score.clip(-100, 0), attn_topk_score.detach().log().clip(-100, 0), log_target=True, reduction="sum")
o = register_loss(o, loss)
return o.to(dtype), topk_indices
......@@ -101,11 +94,11 @@ def ref_deepseek_sparse_attention(
all_o, all_topk_indices = [], []
for i in range(offsets.shape[0] - 1):
o, topk_indices = ref_deepseek_sparse_attention_innner(
q[None, offsets[i]:offsets[i + 1]],
kv[None, offsets[i]:offsets[i + 1]],
index_q[None, offsets[i]:offsets[i + 1]],
index_k[None, offsets[i]:offsets[i + 1]],
weights[None, offsets[i]:offsets[i + 1]],
q[None, offsets[i] : offsets[i + 1]],
kv[None, offsets[i] : offsets[i + 1]],
index_q[None, offsets[i] : offsets[i + 1]],
index_k[None, offsets[i] : offsets[i + 1]],
weights[None, offsets[i] : offsets[i + 1]],
topk,
dim_v,
sm_scale,
......@@ -119,7 +112,6 @@ def ref_deepseek_sparse_attention(
class DSAFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
......@@ -134,12 +126,9 @@ class DSAFunction(torch.autograd.Function):
sm_scale: Optional[float] = None,
):
# topk_indices, index_score = ref_index_score(index_q, weights, index_k, topk)
topk_indices, index_score = indexer_topk_reducesum_interface(index_q, weights, index_k,
topk, offsets)
o, lse = sparse_mla_fwd_interface(
q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), offsets, sm_scale=sm_scale, d_v=dim_v)
ctx.save_for_backward(q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse,
offsets)
topk_indices, index_score = indexer_topk_reducesum_interface(index_q, weights, index_k, topk, offsets)
o, lse = sparse_mla_fwd_interface(q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), offsets, sm_scale=sm_scale, d_v=dim_v)
ctx.save_for_backward(q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets)
ctx.topk = topk
ctx.dim_v = dim_v
ctx.sm_scale = sm_scale
......@@ -153,19 +142,10 @@ class DSAFunction(torch.autograd.Function):
):
q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets = ctx.saved_tensors
attn_score = sparse_mla_topk_reducesum_interface(
q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), lse, offsets,
dim_v=ctx.dim_v).squeeze(-2)
dq, dkv = sparse_mla_bwd(
q,
kv.unsqueeze(-2),
o,
do,
topk_indices.unsqueeze(-2),
lse,
offsets,
sm_scale=ctx.sm_scale)
dindex_q, dweights, dindex_k = indexer_bwd_interface(index_q, weights, index_k, attn_score,
index_score, topk_indices, offsets)
q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), lse, offsets, dim_v=ctx.dim_v
).squeeze(-2)
dq, dkv = sparse_mla_bwd(q, kv.unsqueeze(-2), o, do, topk_indices.unsqueeze(-2), lse, offsets, sm_scale=ctx.sm_scale)
dindex_q, dweights, dindex_k = indexer_bwd_interface(index_q, weights, index_k, attn_score, index_score, topk_indices, offsets)
return dq, dkv.squeeze(-2), dindex_q, dindex_k, dweights, None, None, None, None
......@@ -209,8 +189,7 @@ def test_kernel(
index_k_grad, index_k.grad = index_k.grad, None
weights_grad, weights.grad = weights.grad, None
ref_o, ref_topk_indices = ref_deepseek_sparse_attention(q, kv, index_q, index_k, weights,
offsets, topk, D)
ref_o, ref_topk_indices = ref_deepseek_sparse_attention(q, kv, index_q, index_k, weights, offsets, topk, D)
ref_o.backward(do)
ref_q_grad, q.grad = q.grad, None
ref_kv_grad, kv.grad = kv.grad, None
......@@ -219,28 +198,20 @@ def test_kernel(
ref_weights_grad, weights.grad = weights.grad, None
print(f"o err: {get_abs_err(o, ref_o):.6f} ratio: {get_err_ratio(o, ref_o):.6f}")
print(
f"q.grad err: {get_abs_err(q_grad, ref_q_grad):.6f} ratio: {get_err_ratio(q_grad, ref_q_grad):.6f}"
)
print(
f"kv.grad err: {get_abs_err(kv_grad, ref_kv_grad):.6f} ratio: {get_err_ratio(kv_grad, ref_kv_grad):.6f}"
)
print(f"q.grad err: {get_abs_err(q_grad, ref_q_grad):.6f} ratio: {get_err_ratio(q_grad, ref_q_grad):.6f}")
print(f"kv.grad err: {get_abs_err(kv_grad, ref_kv_grad):.6f} ratio: {get_err_ratio(kv_grad, ref_kv_grad):.6f}")
print(
f"index_q.grad err: {get_abs_err(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f} ratio: {get_err_ratio(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f}"
)
print(
f"index_k.grad err: {get_abs_err(index_k_grad, ref_index_k_grad):.6f} ratio: {get_err_ratio(index_k_grad, ref_index_k_grad):.6f}"
)
print(
f"weights.grad err: {get_abs_err(weights_grad, ref_weights_grad):.6f} ratio: {get_err_ratio(weights_grad, ref_weights_grad):.6f}"
)
print(f"index_k.grad err: {get_abs_err(index_k_grad, ref_index_k_grad):.6f} ratio: {get_err_ratio(index_k_grad, ref_index_k_grad):.6f}")
print(f"weights.grad err: {get_abs_err(weights_grad, ref_weights_grad):.6f} ratio: {get_err_ratio(weights_grad, ref_weights_grad):.6f}")
intersections = []
for j in range(S):
ref_np = ref_topk_indices[j].cpu().to(torch.int32).numpy()
trt_np = topk_indices[j].cpu().to(torch.int32).numpy()
mask = (trt_np != -1)
mask = trt_np != -1
set_ref = set(ref_np[mask])
set_trt = set(trt_np[mask])
......
......@@ -5,7 +5,9 @@ import functools
from typing import Callable, Any
def tensor_cache(fn: Callable[..., torch.Tensor],) -> Callable[..., torch.Tensor]:
def tensor_cache(
fn: Callable[..., torch.Tensor],
) -> Callable[..., torch.Tensor]:
"""
A decorator that caches the most recent result of a function with tensor inputs.
......@@ -29,10 +31,12 @@ def tensor_cache(fn: Callable[..., torch.Tensor],) -> Callable[..., torch.Tensor
def wrapper(*args: Any, **kwargs: Any) -> Any:
nonlocal last_args, last_kwargs, last_result
if (last_args is not None and last_kwargs is not None) and \
(len(args) == len(last_args) and len(kwargs) == len(last_kwargs)) and \
all(a is b for a, b in zip(args, last_args, strict=False)) and \
all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()):
if (
(last_args is not None and last_kwargs is not None)
and (len(args) == len(last_args) and len(kwargs) == len(last_kwargs))
and all(a is b for a, b in zip(args, last_args, strict=False))
and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items())
):
return last_result
result = fn(*args, **kwargs)
......@@ -56,16 +60,15 @@ def prepare_cu_seqlens_from_lens(
@tensor_cache
def prepare_lens_from_cu_seqlens(cu_seqlens: torch.LongTensor,) -> torch.LongTensor:
def prepare_lens_from_cu_seqlens(
cu_seqlens: torch.LongTensor,
) -> torch.LongTensor:
return torch.diff(cu_seqlens)
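# Hypothetical usage sketch of the tensor_cache decorator defined above: the cache is keyed
# on argument identity (`is`), so repeated calls with the same tensor object reuse the
# previous result, while passing a different tensor triggers recomputation.
import torch

@tensor_cache
def square_sum(x: torch.Tensor) -> torch.Tensor:
    return (x * x).sum()

t = torch.arange(4.0)
a = square_sum(t)  # computed
b = square_sum(t)  # returned from the single-entry cache (same object passed again)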
@tensor_cache
def prepare_position_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
return torch.cat([
torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device)
for n in prepare_lens(cu_seqlens).unbind()
])
return torch.cat([torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) for n in prepare_lens(cu_seqlens).unbind()])
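# A minimal illustration (assumed cu_seqlens) of the flattened position ids produced for a
# packed batch: two sequences of lengths 3 and 2 give [0, 1, 2, 0, 1].
import torch

cu_seqlens = torch.tensor([0, 3, 5])
lens = torch.diff(cu_seqlens)  # tensor([3, 2]), i.e. prepare_lens_from_cu_seqlens(cu_seqlens)
position_ids = torch.cat([torch.arange(n) for n in lens.tolist()])
# position_ids -> tensor([0, 1, 2, 0, 1])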
@tensor_cache
......
......@@ -49,17 +49,17 @@ def tl_indexer_bwd_impl(
@T.prim_func
def tl_indexer_bwd_kernel(
IndexQ: T.Tensor(index_q_shape, dtype),
Weights: T.Tensor(weights_shape, dtype),
IndexK: T.Tensor(index_k_shape, dtype),
dIndexQ: T.Tensor(index_q_shape, dtype),
dWeights: T.Tensor(weights_shape, dtype),
dIndexK: T.Tensor(index_k_shape, dtype),
AttnScore: T.Tensor(shape_p, FP32),
IndexScore: T.Tensor(shape_p, FP32),
TopkIndices: T.Tensor(topk_indices_shape, INT32),
Offsets: T.Tensor(offsets_shape, INT32),
TokenIndices: T.Tensor(token_indices_shape, INT32),
IndexQ: T.Tensor(index_q_shape, dtype),
Weights: T.Tensor(weights_shape, dtype),
IndexK: T.Tensor(index_k_shape, dtype),
dIndexQ: T.Tensor(index_q_shape, dtype),
dWeights: T.Tensor(weights_shape, dtype),
dIndexK: T.Tensor(index_k_shape, dtype),
AttnScore: T.Tensor(shape_p, FP32),
IndexScore: T.Tensor(shape_p, FP32),
TopkIndices: T.Tensor(topk_indices_shape, INT32),
Offsets: T.Tensor(offsets_shape, INT32),
TokenIndices: T.Tensor(token_indices_shape, INT32),
):
with T.Kernel(seq_len, threads=num_threads) as (bx):
i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1]
......@@ -81,7 +81,6 @@ def tl_indexer_bwd_impl(
index_q_shared[i, j] = index_q_shared[i, j] * sm_scale
for bi_i in T.Pipelined(num_blocks, num_stages=num_stages):
i_st = bi_i * block_I
i_ed = (bi_i + 1) * block_I
......@@ -91,8 +90,7 @@ def tl_indexer_bwd_impl(
index_k_shared = T.alloc_shared([block_I, dim], dtype=dtype)
for i, j in T.Parallel(block_I, dim):
pos = indices_shared[i]
index_k_shared[i, j] = T.if_then_else((pos > -1) & (pos <= i_t),
IndexK[bos + pos, j], 0)
index_k_shared[i, j] = T.if_then_else((pos > -1) & (pos <= i_t), IndexK[bos + pos, j], 0)
attn_score_shared = T.alloc_shared([block_I], dtype=accum_dtype)
index_score_shared = T.alloc_shared([block_I], dtype=accum_dtype)
......@@ -115,8 +113,7 @@ def tl_indexer_bwd_impl(
# dw
d_weights_i = T.alloc_fragment((block_I, heads), accum_dtype)
for i, j in T.Parallel(block_I, heads):
d_weights_i[i,
j] = (index_score_shared[i] - attn_score_shared[i]) * logits[i, j]
d_weights_i[i, j] = (index_score_shared[i] - attn_score_shared[i]) * logits[i, j]
T.reduce_sum(d_weights_i, d_weights_frag, dim=0, clear=False)
d_logits_qk = T.alloc_shared((block_I, heads), accum_dtype)
......@@ -129,8 +126,7 @@ def tl_indexer_bwd_impl(
d_relu = 1.0
else:
d_relu = 0.0
d_logits_qk[i, j] = (index_score_shared[i] -
attn_score_shared[i]) * d_relu * weights_shared[j]
d_logits_qk[i, j] = (index_score_shared[i] - attn_score_shared[i]) * d_relu * weights_shared[j]
# dq
T.copy(d_logits_qk, d_logits_qk_cast1)
......@@ -157,7 +153,7 @@ def tl_indexer_bwd_impl(
for i, j in T.Parallel(block_I, dim):
pos = indices_shared[i]
if ((pos > -1) & (pos <= i_t)):
if (pos > -1) & (pos <= i_t):
T.atomic_add(dIndexK[bos + pos, j], d_index_k_frag[i, j])
for i, j in T.Parallel(heads, dim):
......@@ -184,40 +180,35 @@ def indexer_bwd_interface(
dweights = torch.zeros_like(weights)
dk = torch.zeros_like(k)
kernel = tl_indexer_bwd_impl(heads, dim, topk)
kernel(q, weights, k, dq, dweights, dk, attn_score, index_score, topk_indices, offsets,
token_indices)
kernel(q, weights, k, dq, dweights, dk, attn_score, index_score, topk_indices, offsets, token_indices)
return dq, dweights, dk
def ref_indexer_bwd(Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor,
TopkIndices: torch.Tensor, AttnScore: torch.Tensor,
offsets: torch.Tensor) -> torch.Tensor:
def ref_indexer_bwd(
Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor, AttnScore: torch.Tensor, offsets: torch.Tensor
) -> torch.Tensor:
Q.requires_grad_(True)
Weights.requires_grad_(True)
K.requires_grad_(True)
softmax_scale = Q.shape[-1]**-0.5
softmax_scale = Q.shape[-1] ** -0.5
all_loss = []
all_log_topk_prob = []
for i in range(offsets.shape[0] - 1):
assert (offsets[i + 1] - offsets[i]).item() >= TopkIndices.shape[-1]
q = Q[offsets[i]:offsets[i + 1]]
weights = Weights[offsets[i]:offsets[i + 1]]
k = K[offsets[i]:offsets[i + 1]]
topk_indices = TopkIndices[offsets[i]:offsets[i + 1]]
attn_score = AttnScore[offsets[i]:offsets[i + 1]]
q = Q[offsets[i] : offsets[i + 1]]
weights = Weights[offsets[i] : offsets[i + 1]]
k = K[offsets[i] : offsets[i + 1]]
topk_indices = TopkIndices[offsets[i] : offsets[i + 1]]
attn_score = AttnScore[offsets[i] : offsets[i + 1]]
s = q.shape[0]
mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device)
logits = einsum(q, k, 's1 h k, s2 k -> s1 h s2') * softmax_scale
logits = einsum(q, k, "s1 h k, s2 k -> s1 h s2") * softmax_scale
logits = F.relu(logits)
score = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32)
score = torch.where(mask, score, float('-inf'))
score = torch.where(mask, score, float("-inf"))
topk_value = torch.gather(score, dim=-1, index=topk_indices.to(torch.int64))
log_topk_prob = F.log_softmax(topk_value, dim=-1, dtype=torch.float32)
loss = F.kl_div(
log_topk_prob.clip(-100, 0),
attn_score.log().clip(-100, 0),
log_target=True,
reduction="sum")
loss = F.kl_div(log_topk_prob.clip(-100, 0), attn_score.log().clip(-100, 0), log_target=True, reduction="sum")
all_loss.append(loss)
all_log_topk_prob.append(log_topk_prob)
loss = torch.stack(all_loss).sum()
......@@ -244,15 +235,13 @@ def test_kernel(
seq_len = (offsets[i + 1] - offsets[i]).item()
mask = (torch.arange(seq_len)[:, None] >= torch.arange(topk)[None, :]).to(q.device)
logits = torch.ones(seq_len, topk).cuda()
logits = torch.where(mask, logits, float('-inf'))
logits = torch.where(mask, logits, float("-inf"))
attn_score = F.softmax(logits, dim=-1, dtype=torch.float32)
all_attn_score.append(attn_score)
attn_score = torch.cat(all_attn_score, dim=0)
topk_indices = repeat(
torch.arange(topk, dtype=torch.int32).cuda(), 'k -> s k', s=S).contiguous()
index_score, ref_dq, ref_dw, ref_dk = ref_indexer_bwd(q, w, k, topk_indices, attn_score,
offsets)
topk_indices = repeat(torch.arange(topk, dtype=torch.int32).cuda(), "k -> s k", s=S).contiguous()
index_score, ref_dq, ref_dw, ref_dk = ref_indexer_bwd(q, w, k, topk_indices, attn_score, offsets)
dq, dw, dk = indexer_bwd_interface(q, w, k, attn_score, index_score, topk_indices, offsets)
......@@ -261,5 +250,5 @@ def test_kernel(
print(f"dq err: {get_abs_err(dk, ref_dk):.6f} ratio: {get_err_ratio(dk, ref_dk):.6f}")
if __name__ == '__main__':
if __name__ == "__main__":
test_kernel()
......@@ -53,8 +53,8 @@ def tl_indexer_topk_reducesum_impl(
@T.macro
def bitonic_sort(
topk_index_shared: T.SharedBuffer([N], dtype=INT32),
topk_value_shared: T.SharedBuffer([N], dtype=FP32),
topk_index_shared: T.SharedBuffer([N], dtype=INT32),
topk_value_shared: T.SharedBuffer([N], dtype=FP32),
):
T.sync_threads()
for i1 in T.serial(num_iters):
......@@ -62,9 +62,10 @@ def tl_indexer_topk_reducesum_impl(
for i in T.Parallel(N):
ascending = (i & (1 << (i1 + 1))) != 0
j = i ^ (1 << (i1 - i2))
if i < j and \
((ascending and topk_value_shared[i] > topk_value_shared[j]) or (
not ascending and topk_value_shared[i] < topk_value_shared[j])):
if i < j and (
(ascending and topk_value_shared[i] > topk_value_shared[j])
or (not ascending and topk_value_shared[i] < topk_value_shared[j])
):
val = topk_value_shared[i]
topk_value_shared[i] = topk_value_shared[j]
topk_value_shared[j] = val
......@@ -75,13 +76,13 @@ def tl_indexer_topk_reducesum_impl(
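# A plain-Python sketch (not TileLang, assuming N is a power of two) of the comparator
# network in the bitonic_sort macro above: it sorts values in descending order and carries
# the paired indices along, the role the macro plays for topk_value_shared/topk_index_shared.
def bitonic_sort_desc(values, indices):
    n = len(values)
    num_iters = n.bit_length() - 1  # log2(n) for power-of-two n
    for i1 in range(num_iters):
        for i2 in range(i1 + 1):
            for i in range(n):  # each i meets its partner j exactly once per pass
                ascending = (i & (1 << (i1 + 1))) != 0
                j = i ^ (1 << (i1 - i2))
                if i < j and ((ascending and values[i] > values[j]) or
                              (not ascending and values[i] < values[j])):
                    values[i], values[j] = values[j], values[i]
                    indices[i], indices[j] = indices[j], indices[i]
    return values, indices

# bitonic_sort_desc([3.0, 1.0, 4.0, 2.0], [0, 1, 2, 3]) -> ([4.0, 3.0, 2.0, 1.0], [2, 0, 3, 1])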
@T.prim_func
def tl_indexer_topk_reducesum_kernel(
IndexQ: T.Tensor(index_q_shape, dtype),
Weights: T.Tensor(weights_shape, dtype),
IndexK: T.Tensor(index_k_shape, dtype),
TopkIndices: T.Tensor(topk_indices_shape, INT32),
ReduceSum: T.Tensor(topk_indices_shape, FP32),
Offsets: T.Tensor(offsets_shape, INT32),
TokenIndices: T.Tensor(token_indices_shape, INT32),
IndexQ: T.Tensor(index_q_shape, dtype),
Weights: T.Tensor(weights_shape, dtype),
IndexK: T.Tensor(index_k_shape, dtype),
TopkIndices: T.Tensor(topk_indices_shape, INT32),
ReduceSum: T.Tensor(topk_indices_shape, FP32),
Offsets: T.Tensor(offsets_shape, INT32),
TokenIndices: T.Tensor(token_indices_shape, INT32),
):
with T.Kernel(seq_len, threads=num_threads) as (bx):
i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1]
......@@ -92,7 +93,7 @@ def tl_indexer_topk_reducesum_impl(
topk_value_shared = T.alloc_shared([N], dtype=FP32)
T.fill(topk_index_shared, -1)
T.fill(topk_value_shared, float('-inf'))
T.fill(topk_value_shared, float("-inf"))
T.sync_threads()
index_q_shared = T.alloc_shared([heads, dim], dtype=dtype)
......@@ -113,8 +114,7 @@ def tl_indexer_topk_reducesum_impl(
index_k_shared = T.alloc_shared([block_K, dim], dtype=dtype)
for i, j in T.Parallel(block_K, dim):
index_k_shared[i, j] = T.if_then_else(k_st + i < k_ed, IndexK[bos + k_st + i,
j], 0)
index_k_shared[i, j] = T.if_then_else(k_st + i < k_ed, IndexK[bos + k_st + i, j], 0)
T.sync_threads()
logits = T.alloc_fragment((block_K, heads), FP32)
......@@ -144,7 +144,7 @@ def tl_indexer_topk_reducesum_impl(
T.sync_threads()
for i in T.Parallel(block_K):
if k_st + i > i_t:
logits_sum[i] = float('-inf')
logits_sum[i] = float("-inf")
j = offset + i
topk_index_shared[j] = k_st + i
topk_value_shared[j] = logits_sum[i]
......@@ -209,22 +209,21 @@ def indexer_topk_reducesum_interface(
return topk_indices, topk_score
def ref_index_score(Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, topk: int,
offsets: torch.Tensor) -> torch.Tensor:
def ref_index_score(Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, topk: int, offsets: torch.Tensor) -> torch.Tensor:
all_topk_indices = []
all_topk_score = []
for i in range(offsets.shape[0] - 1):
assert (offsets[i + 1] - offsets[i]).item() >= topk
q = Q[offsets[i]:offsets[i + 1]]
weights = Weights[offsets[i]:offsets[i + 1]]
k = K[offsets[i]:offsets[i + 1]]
softmax_scale = q.shape[-1]**-0.5
q = Q[offsets[i] : offsets[i + 1]]
weights = Weights[offsets[i] : offsets[i + 1]]
k = K[offsets[i] : offsets[i + 1]]
softmax_scale = q.shape[-1] ** -0.5
s = q.shape[0]
mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device)
logits = einsum(q, k, 's1 h k, s2 k -> s1 h s2')
logits = einsum(q, k, "s1 h k, s2 k -> s1 h s2")
logits = F.relu(logits)
logits = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) * softmax_scale
logits = torch.where(mask, logits, float('-inf'))
logits = torch.where(mask, logits, float("-inf"))
topk_logits, topk_indices = torch.topk(logits, k=topk, dim=-1)
topk_score = F.softmax(topk_logits, dim=-1, dtype=torch.float32)
all_topk_indices.append(topk_indices)
......@@ -265,13 +264,10 @@ def test_kernel(
set_trt = set(trt_np[mask])
intersection = set_ref & set_trt
print("idx:", j, "selected/all:", len(intersection), "/", len(set_ref), "=",
len(intersection) / len(set_ref))
print("idx:", j, "selected/all:", len(intersection), "/", len(set_ref), "=", len(intersection) / len(set_ref))
print(
f"err: {get_abs_err(ref_np_val, trt_np_val):.6f} ratio: {get_err_ratio(ref_np_val, trt_np_val):.6f}"
)
print(f"err: {get_abs_err(ref_np_val, trt_np_val):.6f} ratio: {get_err_ratio(ref_np_val, trt_np_val):.6f}")
if __name__ == '__main__':
if __name__ == "__main__":
test_kernel()
......@@ -19,15 +19,15 @@ def preprocess(
assert dtype == "bfloat16"
assert accum_dtype == "float"
S = T.symbolic('S')
S = T.symbolic("S")
shape = [S, H, D]
@T.prim_func
def preprocess_kernel(
O: T.Tensor(shape, dtype),
dO: T.Tensor(shape, dtype),
Delta: T.Tensor([S, H], accum_dtype),
O: T.Tensor(shape, dtype),
dO: T.Tensor(shape, dtype),
Delta: T.Tensor([S, H], accum_dtype),
):
with T.Kernel(H, T.ceildiv(S, block_ND)) as (bx, by):
o = T.alloc_fragment([block_ND, block_ND], accum_dtype)
......@@ -36,13 +36,12 @@ def preprocess(
acc = T.alloc_fragment([block_ND, block_ND], accum_dtype)
T.clear(acc)
for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages):
T.copy(O[by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], o)
T.copy(dO[by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND],
do)
T.copy(O[by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], o)
T.copy(dO[by * block_ND : (by + 1) * block_ND, bx, k * block_ND : (k + 1) * block_ND], do)
for i, j in T.Parallel(block_ND, block_ND):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[by * block_ND:(by + 1) * block_ND, bx])
T.copy(delta, Delta[by * block_ND : (by + 1) * block_ND, bx])
return preprocess_kernel
......@@ -59,19 +58,19 @@ def postprocess(
):
assert dtype == "bfloat16"
assert accum_dtype == "float"
S_kv = T.symbolic('S_kv')
S_kv = T.symbolic("S_kv")
dkv_shape = [S_kv, kv_group, D + D_tail]
@T.prim_func
def postprocess_kernel(
dKV: T.Tensor(dkv_shape, accum_dtype),
dKV_out: T.Tensor(dkv_shape, dtype),
dKV: T.Tensor(dkv_shape, accum_dtype),
dKV_out: T.Tensor(dkv_shape, dtype),
):
with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, threads=threads) as (bx, by):
T.copy(
dKV[bx * block_N:(bx + 1) * block_N, by, :],
dKV_out[bx * block_N:(bx + 1) * block_N, by, :],
dKV[bx * block_N : (bx + 1) * block_N, by, :],
dKV_out[bx * block_N : (bx + 1) * block_N, by, :],
)
return postprocess_kernel
......@@ -82,7 +81,8 @@ def postprocess(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
},
)
def bwd(
H,
D,
......@@ -98,17 +98,17 @@ def bwd(
dtype="bfloat16",
accum_dtype="float",
):
assert is_causal == True, 'non-causal is not supported now'
assert topk % block_size == 0, 'otherwise some index=0 entries would be loaded, causing the wrong KV rows to be read'
assert is_causal == True, "non-causal is not supported now"
assert topk % block_size == 0, "otherwise some index=0 entries would be loaded, causing the wrong KV rows to be read"
assert dtype == "bfloat16"
assert accum_dtype == "float"
assert indices_dtype == "int32"
if sm_scale is None:
sm_scale = (D + D_tail)**(-0.5)
sm_scale = (D + D_tail) ** (-0.5)
B_plus_one = T.symbolic('B_plus_one')
S = T.symbolic('S')
B_plus_one = T.symbolic("B_plus_one")
S = T.symbolic("S")
H_kv = H // kv_group
q_shape = [S, H, D + D_tail]
......@@ -132,16 +132,16 @@ def bwd(
@T.prim_func
def sparse_mla_bwd_kernel(
Q: T.Tensor(q_shape, dtype),
KV: T.Tensor(k_shape, dtype),
dO: T.Tensor(o_shape, dtype),
Indices: T.Tensor(indices_shape, indices_dtype),
Lse: T.Tensor(lse_shape, accum_dtype),
Delta: T.Tensor(delta_shape, accum_dtype),
Offsets: T.Tensor(offsets_shape, indices_dtype),
TokenIndices: T.Tensor(token_indices_shape, indices_dtype),
dQ: T.Tensor(q_shape, dtype),
dKV: T.Tensor(k_shape, accum_dtype),
Q: T.Tensor(q_shape, dtype),
KV: T.Tensor(k_shape, dtype),
dO: T.Tensor(o_shape, dtype),
Indices: T.Tensor(indices_shape, indices_dtype),
Lse: T.Tensor(lse_shape, accum_dtype),
Delta: T.Tensor(delta_shape, accum_dtype),
Offsets: T.Tensor(offsets_shape, indices_dtype),
TokenIndices: T.Tensor(token_indices_shape, indices_dtype),
dQ: T.Tensor(q_shape, dtype),
dKV: T.Tensor(k_shape, accum_dtype),
):
with T.Kernel(S, kv_group, threads=threads) as (b_s_i, bz):
Q_shared = T.alloc_shared([padded_H, D], dtype)
......@@ -163,32 +163,32 @@ def bwd(
acc_dkv = T.alloc_fragment([BS, D], accum_dtype)
acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype)
acc_dkv_shared = T.view(KV_shared, shape=[BS // split_store, D], dtype=accum_dtype)
acc_dkv_tail_shared = T.view(
KV_tail_shared, shape=[BS // split_store, D_tail], dtype=accum_dtype)
acc_dkv_tail_shared = T.view(KV_tail_shared, shape=[BS // split_store, D_tail], dtype=accum_dtype)
b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1]
bos, eos = Offsets[b_i], Offsets[b_i + 1]
max_kv_i = s_i
T.copy(Q[bos + s_i, bz * padded_H:(bz + 1) * padded_H, :D], Q_shared)
T.copy(Q[bos + s_i, bz * padded_H:(bz + 1) * padded_H, D:], Q_tail_shared)
T.copy(dO[bos + s_i, bz * padded_H:(bz + 1) * padded_H, :D], dO_shared)
T.copy(Q[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D], Q_shared)
T.copy(Q[bos + s_i, bz * padded_H : (bz + 1) * padded_H, D:], Q_tail_shared)
T.copy(dO[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D], dO_shared)
T.clear(acc_dq)
T.clear(acc_dq_tail)
T.annotate_layout({
dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared),
dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared),
})
T.annotate_layout(
{
dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared),
dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared),
}
)
# Process each block of indices
for i_i in T.Pipelined(NS, num_stages=num_stages):
# Check which indices are valid
for bi_i in T.Parallel(BS):
mask[bi_i] = (Indices[bos + s_i, bz, i_i * BS + bi_i] <= max_kv_i) & (
Indices[bos + s_i, bz, i_i * BS + bi_i] != -1)
mask[bi_i] = (Indices[bos + s_i, bz, i_i * BS + bi_i] <= max_kv_i) & (Indices[bos + s_i, bz, i_i * BS + bi_i] != -1)
# Compute attention scores
for h_i, bi_i in T.Parallel(padded_H, BS):
......@@ -196,65 +196,33 @@ def bwd(
# Load KV, V for this block of indices
for bi_i, d_i in T.Parallel(BS, D):
KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz,
d_i]
KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, d_i]
T.gemm(
Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
for bi_i, d_i in T.Parallel(BS, D_tail):
KV_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i],
bz, D + d_i]
T.gemm(
Q_tail_shared,
KV_tail_shared,
acc_p,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
KV_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz, D + d_i]
T.gemm(Q_tail_shared, KV_tail_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
for h_i, bi_i in T.Parallel(padded_H, BS):
acc_p[h_i, bi_i] = T.exp(acc_p[h_i, bi_i] * sm_scale -
Lse[bos + s_i, bz * padded_H + h_i])
acc_p[h_i, bi_i] = T.exp(acc_p[h_i, bi_i] * sm_scale - Lse[bos + s_i, bz * padded_H + h_i])
T.copy(acc_p, P_shared_cast)
T.gemm(
dO_shared,
KV_shared,
acc_dp,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol,
clear_accum=True)
T.gemm(dO_shared, KV_shared, acc_dp, transpose_B=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
for h_i, bi_i in T.Parallel(padded_H, BS):
acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (
acc_dp[h_i, bi_i] - Delta[bos + s_i, bz * padded_H + h_i]) * sm_scale
acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[bos + s_i, bz * padded_H + h_i]) * sm_scale
T.copy(acc_dp, dP_shared_cast)
T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol)
T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol)
T.gemm(
dP_shared_cast,
Q_shared,
acc_dkv,
transpose_A=True,
policy=T.GemmWarpPolicy.FullCol,
clear_accum=True)
T.gemm(
P_shared_cast,
dO_shared,
acc_dkv,
transpose_A=True,
policy=T.GemmWarpPolicy.FullCol)
T.gemm(dP_shared_cast, Q_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol, clear_accum=True)
T.gemm(P_shared_cast, dO_shared, acc_dkv, transpose_A=True, policy=T.GemmWarpPolicy.FullCol)
T.clear(acc_dkv_tail)
T.gemm(
dP_shared_cast,
Q_tail_shared,
acc_dkv_tail,
transpose_A=True,
policy=T.GemmWarpPolicy.FullCol)
T.gemm(dP_shared_cast, Q_tail_shared, acc_dkv_tail, transpose_A=True, policy=T.GemmWarpPolicy.FullCol)
for s in range(split_store):
for bi_i, d_i in T.Parallel(BS, D):
......@@ -263,44 +231,32 @@ def bwd(
for bi_i, d_i in T.Parallel(BS, D_tail):
if bi_i < BS // split_store:
acc_dkv_tail_shared[bi_i,
d_i] = acc_dkv_tail[bi_i + s * (BS // split_store),
d_i]
acc_dkv_tail_shared[bi_i, d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), d_i]
for bi_i, d_i in T.Parallel(BS // split_store, D // 4):
T.atomic_addx4(
dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s *
(BS // split_store)], bz, d_i * 4],
acc_dkv_shared[bi_i, d_i * 4])
dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, d_i * 4],
acc_dkv_shared[bi_i, d_i * 4],
)
# Atomically update dKV, dKV_tail tensors
for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4):
T.atomic_addx4(
dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s *
(BS // split_store)], bz, D + d_i * 4],
acc_dkv_tail_shared[bi_i, d_i * 4])
dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, D + d_i * 4],
acc_dkv_tail_shared[bi_i, d_i * 4],
)
# Store the accumulated dQ
T.copy(acc_dq, dQ_shared)
T.copy(acc_dq_tail, dQ_tail_shared)
T.copy(dQ_shared, dQ[bos + s_i, bz * padded_H:(bz + 1) * padded_H, :D])
T.copy(dQ_tail_shared, dQ[bos + s_i, bz * padded_H:(bz + 1) * padded_H, D:])
T.copy(dQ_shared, dQ[bos + s_i, bz * padded_H : (bz + 1) * padded_H, :D])
T.copy(dQ_tail_shared, dQ[bos + s_i, bz * padded_H : (bz + 1) * padded_H, D:])
return sparse_mla_bwd_kernel
def sparse_mla_bwd(q,
kv,
o,
do,
indices,
lse,
offsets,
sm_scale=None,
is_casual=True,
return_kernel=False,
delta=None):
def sparse_mla_bwd(q, kv, o, do, indices, lse, offsets, sm_scale=None, is_casual=True, return_kernel=False, delta=None):
assert q.is_contiguous()
assert kv.is_contiguous()
assert indices.is_contiguous()
......@@ -333,16 +289,9 @@ def sparse_mla_bwd(q,
return dq, dkv
def ref_sparse_mla_bwd_interface(q,
kv,
o,
do,
indices,
lse,
offsets,
sm_scale=None,
is_casual=True):
def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, offsets, sm_scale=None, is_casual=True):
from sparse_mla_fwd import ref_sparse_mla_fwd_interface
q = q.detach().clone()
kv = kv.detach().clone()
q.requires_grad = True
......@@ -352,32 +301,25 @@ def ref_sparse_mla_bwd_interface(q,
return q.grad, kv.grad
def test_sparse_mla_bwd(B=1,
S=2048,
H=64,
HKV=1,
DQKV=576,
DV=512,
topk=512,
dtype=torch.bfloat16,
check_correctness=True):
def test_sparse_mla_bwd(B=1, S=2048, H=64, HKV=1, DQKV=576, DV=512, topk=512, dtype=torch.bfloat16, check_correctness=True):
# Prepare data
q = torch.randn((S, H, DQKV), dtype=dtype, device='cuda').requires_grad_(True)
kv = torch.randn((S, HKV, DQKV), dtype=dtype, device='cuda').requires_grad_(True)
do = torch.randn((S, H, DV), dtype=dtype, device='cuda')
q = torch.randn((S, H, DQKV), dtype=dtype, device="cuda").requires_grad_(True)
kv = torch.randn((S, HKV, DQKV), dtype=dtype, device="cuda").requires_grad_(True)
do = torch.randn((S, H, DV), dtype=dtype, device="cuda")
offsets = torch.tensor([0, S], dtype=torch.int32, device="cuda")
indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device='cuda')
indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device="cuda")
for i in range(offsets.shape[0] - 1):
seq_len = (offsets[i + 1] - offsets[i]).item()
assert seq_len >= topk
for t in range(seq_len):
for h in range(HKV):
i_i = torch.randperm(max(1, t))[:topk]
indices[offsets[i] + t, h, :len(i_i)] = i_i
indices[offsets[i] + t, h, : len(i_i)] = i_i
# Forward
from sparse_mla_fwd import sparse_mla_fwd_interface
tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, offsets)
tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse, offsets)
......@@ -388,13 +330,15 @@ def test_sparse_mla_bwd(B=1,
assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv")
print("assert_tensors_similar passed")
per_token_flop = 2 * sum([
H * DV * topk,
H * DQKV * topk,
H * DQKV * topk,
H * DQKV * topk,
H * DV * topk,
])
per_token_flop = 2 * sum(
[
H * DV * topk,
H * DQKV * topk,
H * DQKV * topk,
H * DQKV * topk,
H * DV * topk,
]
)
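# Worked arithmetic for the defaults used in test_sparse_mla_bwd below
# (H=64, DQKV=576, DV=512, topk=512):
# per_token_flop = 2 * 64 * 512 * (2 * 512 + 3 * 576) = 180,355,072 FLOPs per token.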
from tilelang.profiler import do_bench
def fn():
......@@ -402,19 +346,9 @@ def test_sparse_mla_bwd(B=1,
ms = do_bench(fn, rep=100, warmup=250)
print(f"Average time: {ms:.3f} ms")
print(f'bwd io bandwidth = ',
(B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12)
print(f'bwd tflops = ', per_token_flop * S / (ms * 1e-3) / 1e12)
print(f"bwd io bandwidth = ", (B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12)
print(f"bwd tflops = ", per_token_flop * S / (ms * 1e-3) / 1e12)
if __name__ == "__main__":
test_sparse_mla_bwd(
B=1,
S=2048,
H=64,
HKV=1,
DQKV=576,
DV=512,
topk=512,
dtype=torch.bfloat16,
check_correctness=True)
test_sparse_mla_bwd(B=1, S=2048, H=64, HKV=1, DQKV=576, DV=512, topk=512, dtype=torch.bfloat16, check_correctness=True)