Unverified Commit fc41463c authored by Zhengju Tang, committed by GitHub

[BugFix] Robust gemm policy for sparse_mla_fwd in Hopper and Ada Lovelace architectures (#984)

* [BugFix] Robust gemm policy for sparse_mla_fwd in Hopper and Ada Lovelace architectures

* [Lint]
parent b0b5347a
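
The three T.gemm calls changed below switch their warp-partitioning policy from T.GemmWarpPolicy.FullCol to T.GemmWarpPolicy.FullRow so the kernel runs robustly on Hopper and Ada Lovelace GPUs. For orientation, here is a minimal standalone sketch of where that policy argument sits in a TileLang kernel. It follows TileLang's quickstart GEMM pattern; the decorator and type spellings (tilelang.jit, T.Tensor, T.Pipelined) and the tile sizes are assumptions for illustration, not part of this commit.

# Minimal sketch only: a plain GEMM showing where the `policy=` argument lives.
# Decorator, type names, and tile sizes follow TileLang's quickstart and are
# assumptions for illustration; they are not taken from this commit.
import tilelang
import tilelang.language as T


@tilelang.jit(out_idx=[-1])  # assumption: auto-allocate and return the last buffer
def naive_gemm(M, N, K, block_M=64, block_N=64, block_K=32, dtype="float16"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), "float")
            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                # FullRow spreads the warps across the tile's row (M) dimension;
                # this is the policy the commit adopts for sparse_mla_fwd's GEMMs.
                T.gemm(A_shared, B_shared, C_local, policy=T.GemmWarpPolicy.FullRow)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main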
@@ -136,14 +136,14 @@ def sparse_mla_fwd(
                     KV_shared,
                     acc_s,
                     transpose_B=True,
-                    policy=T.GemmWarpPolicy.FullCol,
+                    policy=T.GemmWarpPolicy.FullRow,
                 )
                 T.gemm(
                     Q_tail_shared,
                     K_tail_shared,
                     acc_s,
                     transpose_B=True,
-                    policy=T.GemmWarpPolicy.FullCol,
+                    policy=T.GemmWarpPolicy.FullRow,
                 )
                 T.copy(m_i, m_i_prev)
                 T.reduce_max(acc_s, m_i, dim=1, clear=False)
@@ -158,7 +158,7 @@ def sparse_mla_fwd(
                     acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i]
                 T.copy(acc_s, S_shared)
-                T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
+                T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
 
                 # Rescale
                 for h_i, d_i in T.Parallel(H_per_block, D):
@@ -174,7 +174,15 @@ def sparse_mla_fwd(
     return main
 
 
-def sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, return_p_sum: bool = False, d_v=512):
+def sparse_mla_fwd_interface(q,
+                             kv,
+                             indices,
+                             sm_scale=None,
+                             return_p_sum: bool = False,
+                             d_v=512,
+                             block_I=64,
+                             num_stages=2,
+                             threads=256):
     is_casual = True
     assert return_p_sum == False, "This kernel file is for fwd only"
     assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
@@ -190,7 +198,17 @@ def sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, return_p_sum: bool =
     _, _, _, topk = indices.shape
     assert indices.shape == (batch, seq_len, kv_group, topk)
 
-    kernel = sparse_mla_fwd(heads, dim, tail_dim, topk, kv_group, sm_scale, is_casual)
+    kernel = sparse_mla_fwd(
+        heads,
+        dim,
+        tail_dim,
+        topk,
+        kv_group,
+        sm_scale,
+        is_casual,
+        block_I=block_I,
+        num_stages=num_stages,
+        threads=threads)
     out, lse = kernel(q, kv, indices)
     return out, lse
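
With the hunk above, sparse_mla_fwd_interface forwards the new block_I, num_stages, and threads knobs straight into the kernel constructor. A hypothetical call might look like the sketch below; the concrete sizes, the 64-dim tail (DQK = DV + 64), and the int32 index dtype are illustrative assumptions loosely modeled on the test in this file, not values fixed by the commit.

import torch

# Illustrative sizes only (the real test in this file uses larger shapes).
B, S, SKV, H, HKV = 1, 256, 1024, 64, 1
DV, DQK, topk = 512, 576, 128  # assumed DQK = DV + 64-dim tail

q = torch.randn((B, S, H, DQK), dtype=torch.bfloat16, device="cuda")
kv = torch.randn((B, SKV, HKV, DQK), dtype=torch.bfloat16, device="cuda")

# Causality-safe gather indices: position t selects min(k, t); int32 is assumed.
k_pos = torch.arange(topk, device="cuda").view(1, 1, 1, topk)
t_pos = torch.arange(S, device="cuda").view(1, S, 1, 1)
indices = torch.minimum(k_pos, t_pos).to(torch.int32).expand(B, S, HKV, topk).contiguous()

# sparse_mla_fwd_interface is the function defined in this file; the tuning
# knobs added by this commit are passed explicitly with their default values.
out, lse = sparse_mla_fwd_interface(
    q, kv, indices, d_v=DV, block_I=64, num_stages=2, threads=256)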
@@ -241,7 +259,10 @@ def test_sparse_mla_fwd(B=1,
                         DV=512,
                         topk=2048,
                         dtype=torch.bfloat16,
-                        check_correctness=True):
+                        check_correctness=True,
+                        block_I=64,
+                        num_stages=2,
+                        threads=256):
     torch.random.manual_seed(0)
     q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True)
     kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True)
@@ -253,7 +274,8 @@ def test_sparse_mla_fwd(B=1,
                 i_i = torch.randperm(max(1, t))[:topk]
                 indices[b, t, h, :len(i_i)] = i_i
 
-    tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices)
+    tl_out, tl_lse = sparse_mla_fwd_interface(
+        q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads)
 
     if check_correctness:
         # otherwise may cause out of memory
@@ -262,7 +284,8 @@ def test_sparse_mla_fwd(B=1,
         print("assert_tensors_similar passed")
 
     def fn():
-        return sparse_mla_fwd_interface(q, kv, indices)
+        return sparse_mla_fwd_interface(
+            q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads)
 
     from tilelang.profiler import do_bench
@@ -287,4 +310,7 @@ if __name__ == "__main__":
         DV=512,
         topk=2048,
         dtype=torch.bfloat16,
-        check_correctness=True)
+        check_correctness=True,
+        block_I=64,
+        num_stages=2,
+        threads=256)
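
Because every parameter of test_sparse_mla_fwd now has a default, the new knobs can also be swept for tuning without touching the rest of the test setup. A hypothetical sweep (timing only, hence check_correctness=False; the candidate values are illustrative and may not all be valid tile configurations for this kernel):

import itertools

# Hypothetical sweep over the tuning knobs added in this commit; it relies only
# on keyword arguments visible in the diff, all other parameters keep defaults.
for block_I, num_stages, threads in itertools.product((32, 64), (1, 2), (128, 256)):
    print(f"block_I={block_I}, num_stages={num_stages}, threads={threads}")
    test_sparse_mla_fwd(
        check_correctness=False, block_I=block_I, num_stages=num_stages, threads=threads)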