Unverified Commit 6dda13c8 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Misc] Add sliding window to flashinfer test (#21282)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 6b46c4b6
...@@ -77,6 +77,7 @@ def ref_paged_attn( ...@@ -77,6 +77,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@pytest.mark.parametrize("sliding_window", [None, 64])
@torch.inference_mode @torch.inference_mode
def test_flashinfer_decode_with_paged_kv( def test_flashinfer_decode_with_paged_kv(
kv_lens: list[int], kv_lens: list[int],
...@@ -85,6 +86,7 @@ def test_flashinfer_decode_with_paged_kv( ...@@ -85,6 +86,7 @@ def test_flashinfer_decode_with_paged_kv(
dtype: torch.dtype, dtype: torch.dtype,
block_size: int, block_size: int,
soft_cap: Optional[float], soft_cap: Optional[float],
sliding_window: Optional[int],
) -> None: ) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
current_platform.seed_everything(0) current_platform.seed_everything(0)
...@@ -136,17 +138,20 @@ def test_flashinfer_decode_with_paged_kv( ...@@ -136,17 +138,20 @@ def test_flashinfer_decode_with_paged_kv(
use_tensor_cores=( use_tensor_cores=(
(num_query_heads//num_kv_heads) > 4) (num_query_heads//num_kv_heads) > 4)
) )
wrapper.plan(kv_indptr, wrapper.plan(
kv_indices, kv_indptr,
kv_last_page_lens, kv_indices,
num_query_heads, kv_last_page_lens,
num_kv_heads, num_query_heads,
head_size, num_kv_heads,
block_size, head_size,
"NONE", block_size,
q_data_type=dtype, "NONE",
kv_data_type=dtype, window_left=sliding_window - 1 if sliding_window is not None else -1,
logits_soft_cap=soft_cap) q_data_type=dtype,
kv_data_type=dtype,
logits_soft_cap=soft_cap,
)
output = wrapper.run(query, key_value_cache) output = wrapper.run(query, key_value_cache)
...@@ -157,7 +162,8 @@ def test_flashinfer_decode_with_paged_kv( ...@@ -157,7 +162,8 @@ def test_flashinfer_decode_with_paged_kv(
kv_lens=kv_lens, kv_lens=kv_lens,
block_tables=block_tables, block_tables=block_tables,
scale=scale, scale=scale,
soft_cap=soft_cap) soft_cap=soft_cap,
sliding_window=sliding_window)
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \ torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"
...@@ -168,12 +174,17 @@ def test_flashinfer_decode_with_paged_kv( ...@@ -168,12 +174,17 @@ def test_flashinfer_decode_with_paged_kv(
@pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@pytest.mark.parametrize("sliding_window", [None, 64])
@torch.inference_mode @torch.inference_mode
def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]], def test_flashinfer_prefill_with_paged_kv(
num_heads: tuple[int, int], seq_lens: list[tuple[int, int]],
head_size: int, dtype: torch.dtype, num_heads: tuple[int, int],
block_size: int, head_size: int,
soft_cap: Optional[float]) -> None: dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
sliding_window: Optional[int],
) -> None:
torch.set_default_device("cuda") torch.set_default_device("cuda")
current_platform.seed_everything(0) current_platform.seed_everything(0)
num_seqs = len(seq_lens) num_seqs = len(seq_lens)
...@@ -242,6 +253,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]], ...@@ -242,6 +253,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
num_kv_heads, num_kv_heads,
head_size, head_size,
block_size, block_size,
window_left=sliding_window - 1 if sliding_window is not None else -1,
q_data_type=dtype, q_data_type=dtype,
kv_data_type=dtype, kv_data_type=dtype,
logits_soft_cap=soft_cap, logits_soft_cap=soft_cap,
...@@ -259,7 +271,8 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]], ...@@ -259,7 +271,8 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]],
kv_lens=kv_lens, kv_lens=kv_lens,
block_tables=block_tables, block_tables=block_tables,
scale=scale, scale=scale,
soft_cap=soft_cap) soft_cap=soft_cap,
sliding_window=sliding_window)
torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \ torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment