Unverified Commit e8f62b20 authored by Trevor Morris's avatar Trevor Morris Committed by GitHub
Browse files

BLackwell cutlass mla: Add check for bad page size/block num combinations (#5431)

parent 88defc4d
...@@ -74,9 +74,11 @@ def cutlass_mla_decode( ...@@ -74,9 +74,11 @@ def cutlass_mla_decode(
f"but got D_q = {D_q}, D_ckv = {D_ckv}, D_latent = {D_latent}, D_rope = {D_rope}" f"but got D_q = {D_q}, D_ckv = {D_ckv}, D_latent = {D_latent}, D_rope = {D_rope}"
) )
assert H == 128, f"H must be 128, but got {H}" assert H == 128, f"H must be 128, but got {H}"
# TODO: There is currently an illegal memory access issue with page size !=
# 128. Change this when it is fixed. assert len(page_table.shape) == 2
assert PAGE_SIZE == 128, f"PAGE_SIZE must be 128, but got {PAGE_SIZE}" B_block_table, block_num = page_table.shape
assert B_block_table == B_q
assert block_num % (128 / PAGE_SIZE) == 0
# TODO(kaixih@nvidia): support fp8 # TODO(kaixih@nvidia): support fp8
assert q_nope_and_q_pe.dtype in ( assert q_nope_and_q_pe.dtype in (
......
...@@ -39,7 +39,7 @@ def ref_mla( ...@@ -39,7 +39,7 @@ def ref_mla(
@pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096]) @pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096])
@pytest.mark.parametrize("bs", [1, 2, 4]) @pytest.mark.parametrize("bs", [1, 2, 4])
@pytest.mark.parametrize("varlen", [False, True]) @pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("block_size", [128]) @pytest.mark.parametrize("block_size", [1, 16, 64, 128])
def test_cutlass_mla_decode( def test_cutlass_mla_decode(
dtype: torch.dtype, mean_seq_len: int, bs: int, varlen: bool, block_size: int dtype: torch.dtype, mean_seq_len: int, bs: int, varlen: bool, block_size: int
): ):
...@@ -62,6 +62,11 @@ def test_cutlass_mla_decode( ...@@ -62,6 +62,11 @@ def test_cutlass_mla_decode(
max_seq_len = seq_lens.max().item() max_seq_len = seq_lens.max().item()
block_num = (max_seq_len + block_size - 1) // block_size block_num = (max_seq_len + block_size - 1) // block_size
# Pad block_num so that small blocks can be packed into full 128-sized CUTLASS tiles.
# One 128-wide tile can hold (128 // block_size) small blocks.
pack_factor = 128 // block_size
block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor
q = torch.randn(bs, h_q, d) q = torch.randn(bs, h_q, d)
block_table = torch.randint(0, bs * block_num, (bs, block_num), dtype=torch.int32) block_table = torch.randint(0, bs * block_num, (bs, block_num), dtype=torch.int32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment