Commit a63157ea authored by skrider's avatar skrider
Browse files

add test for page table overflow

parent 135a1da6
......@@ -2461,3 +2461,47 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus
assert torch.equal(dv, dv)
assert torch.equal(dk, dk)
assert torch.equal(dq, dq)
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("paged_kv_block_size", [16])
@pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize("nheads", [32])
@pytest.mark.parametrize("b", [4])
@pytest.mark.parametrize("n", [10])
@pytest.mark.parametrize("seqlen_q,seqlen_k", [(170, 170)])
def test_flash_attn_paged_kvcache_overflow(
    seqlen_q,
    seqlen_k,
    d,
    nheads,
    b,
    n,
    paged_kv_block_size,
    causal,
):
    """Stress-test flash_attn_with_kvcache paged-KV page-table indexing.

    Runs ``n`` iterations against a shared paged key/value cache, drawing a
    fresh random block table each iteration so arbitrary (possibly aliased)
    page-table entries are exercised. The point of the test is that no
    out-of-bounds access / crash occurs; a minimal output-shape sanity
    check is asserted so a silently broken kernel cannot pass unnoticed.
    """
    device = "cuda"
    # Cache capacity equivalent to 1000 pages of 16 tokens, expressed in
    # units of paged_kv_block_size-token blocks.
    num_blocks = 1000 * 16 // paged_kv_block_size
    key_cache = torch.rand(
        [num_blocks, paged_kv_block_size, nheads, d], dtype=dtype, device=device
    )
    value_cache = torch.rand(
        [num_blocks, paged_kv_block_size, nheads, d], dtype=dtype, device=device
    )
    # cache_seqlens stays all-zero: each iteration writes k/v starting at
    # offset 0 of every sequence's (randomly remapped) pages.
    cache_seqlens = torch.zeros(b, dtype=torch.int32, device=device)
    # Pages needed to hold seqlen_k tokens (ceiling division) — loop-invariant.
    pages_per_seq = (seqlen_k + paged_kv_block_size - 1) // paged_kv_block_size
    for _ in range(n):
        query = torch.rand([b, seqlen_q, nheads, d], dtype=dtype, device=device)
        key = torch.rand([b, seqlen_k, nheads, d], dtype=dtype, device=device)
        value = torch.rand([b, seqlen_k, nheads, d], dtype=dtype, device=device)
        block_tables = torch.randint(
            0,
            num_blocks,
            size=(b, pages_per_seq),
            dtype=torch.int32,
            device=device,
        )
        output = flash_attn_with_kvcache(
            query,
            key_cache,
            value_cache,
            k=key,
            v=value,
            cache_seqlens=cache_seqlens,
            block_table=block_tables,
            causal=causal,
        )
        # The original discarded `output`; assert at least the shape so a
        # kernel that returns garbage dimensions fails the test.
        assert output.shape == (b, seqlen_q, nheads, d)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment