# test_attention_backends.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for v1 attention backends without GPUModelRunner dependency."""
4

5
from functools import partial
6
7
8

import pytest
import torch
9
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
10

11
12
13
14
15
from tests.v1.attention.utils import (
    BatchSpec,
    create_common_attn_metadata,
    create_standard_kv_cache_spec,
    create_vllm_config,
16
    try_get_attention_backend,
17
)
18
from vllm.attention.backends.registry import AttentionBackendEnum
19
20
from vllm.config import ModelConfig
from vllm.platforms import current_platform
21
from vllm.utils.math_utils import cdiv
22
23
24
25
26
from vllm.utils.torch_utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
    is_torch_equal_or_newer,
    set_random_seed,
)
27
28
29
30
from vllm.v1.attention.backends.utils import (
    CommonAttentionMetadata,
    set_kv_cache_layout,
)
31
32
33
from vllm.v1.kv_cache_interface import FullAttentionSpec

BACKENDS_TO_TEST = [
    AttentionBackendEnum.FLASH_ATTN,
    AttentionBackendEnum.FLASHINFER,
    AttentionBackendEnum.FLEX_ATTENTION,
    AttentionBackendEnum.TRITON_ATTN,
    AttentionBackendEnum.TREE_ATTN,
    # Pseudo-backend: FLEX_ATTENTION with direct block-mask building disabled
    # (handled in run_attention_backend).
    "FLEX_ATTENTION_SLOW",
]

# Remove flashinfer from the list if it's not available
try:
    import flashinfer  # noqa: F401
except ImportError:
    BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHINFER)


def _convert_dtype_to_torch(dtype):
    """Convert ModelDType to torch.dtype."""
    if isinstance(dtype, str):
        if dtype == "auto":
            return torch.float16  # Default dtype for testing
        elif dtype in STR_DTYPE_TO_TORCH_DTYPE:
            return STR_DTYPE_TO_TORCH_DTYPE[dtype]
        else:
            raise ValueError(f"Unknown dtype: {dtype}")
    elif isinstance(dtype, torch.dtype):
        return dtype
    else:
        raise ValueError(f"Unknown dtype: {dtype}")


# Define common batch configurations.
# seq_lens are per-request total lengths and query_lens the number of new
# tokens per request; the difference is the context already in the KV cache.
BATCH_SPECS = {
    "small_decode": BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]),
    "small_prefill": BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]),
    "mixed_small": BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]),
    "medium_decode": BatchSpec(
        seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024],
        query_lens=[1, 1, 1, 1, 1, 1, 1, 1],
    ),
    "medium_prefill": BatchSpec(
        seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]
    ),
    "mixed_medium": BatchSpec(
        seq_lens=[512, 1024, 2048, 512, 1024, 2048], query_lens=[1, 1, 1, 7, 7, 7]
    ),
    "large_decode": BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32),
    "large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
    "mixed_large": BatchSpec(
        seq_lens=[1024, 2048, 4096, 1024, 2048, 4096], query_lens=[1, 1, 1, 32, 32, 32]
    ),
    "single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
    "single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
}


def create_and_prepopulate_kv_cache(
    k_contexts: list[torch.Tensor],
    v_contexts: list[torch.Tensor],
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    dtype: torch.dtype,
    device: torch.device,
    num_blocks: int,
    common_attn_metadata: CommonAttentionMetadata,
    randomize_blocks: bool = True,
) -> torch.Tensor:
    """Create a paged KV cache and prepopulate it with context K/V data.

    Each sequence's context tokens are written into a block-aligned run of
    blocks starting at block 1 (block 0 is treated as the null block). The
    occupied blocks are then optionally shuffled. Finally
    ``common_attn_metadata.block_table_tensor`` and
    ``common_attn_metadata.slot_mapping`` are overwritten **in place** so that
    they point at the (possibly shuffled) block locations.

    Args:
        k_contexts: Per-sequence key context tensors, each of shape
            ``(num_context_tokens, num_kv_heads, head_size)``.
        v_contexts: Per-sequence value context tensors, same shapes as
            ``k_contexts``.
        block_size: Number of tokens per cache block.
        num_kv_heads: Number of KV heads.
        head_size: Size of each head.
        dtype: Data type for the cache.
        device: Device to create the cache on.
        num_blocks: Total number of blocks in the cache.
        common_attn_metadata: Supplies ``seq_lens_cpu``,
            ``query_start_loc_cpu`` and ``num_computed_tokens_cpu``. Its
            ``block_table_tensor`` and ``slot_mapping`` are overwritten.
        randomize_blocks: Whether to randomly permute the occupied blocks
            or keep them in sequential order.

    Returns:
        The KV cache tensor of shape
        ``(2, num_blocks, block_size, num_kv_heads, head_size)``.

    Raises:
        ValueError: If ``num_blocks`` is too small to hold every sequence.
    """
    batch_size = len(k_contexts)
    seq_lens = common_attn_metadata.seq_lens_cpu
    query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
    query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    context_lens = common_attn_metadata.num_computed_tokens_cpu
    block_table = common_attn_metadata.block_table_tensor
    slot_mapping = common_attn_metadata.slot_mapping

    # Each sequence owns enough whole blocks for its full length (context plus
    # new tokens). The runs are laid out back to back after the null block 0.
    blocks_per_seq = [cdiv(int(seq_lens[i]), block_size) for i in range(batch_size)]
    blocks_end = 1 + sum(blocks_per_seq)
    if blocks_end > num_blocks:
        raise ValueError(
            f"KV cache needs {blocks_end} blocks (including the null block) "
            f"but only {num_blocks} were requested"
        )

    # Create KV cache
    kv_cache = torch.empty(
        2, num_blocks, block_size, num_kv_heads, head_size, dtype=dtype, device=device
    )
    kv_cache_flat = kv_cache.view(2, -1, num_kv_heads, head_size)

    # Populate the cache with the context tokens, in sequential block order.
    start_block_idx = 1
    for i, n_blocks in enumerate(blocks_per_seq):
        k_context, v_context = k_contexts[i], v_contexts[i]
        start = start_block_idx * block_size
        end = start + k_context.shape[0]
        kv_cache_flat[0, start:end, ...] = k_context
        kv_cache_flat[1, start:end, ...] = v_context
        start_block_idx += n_blocks

    # Permute the occupied blocks. Block 0 (the null block) never moves.
    if randomize_blocks:
        perm = torch.randperm(blocks_end - 1) + 1
    else:
        perm = torch.arange(1, blocks_end)

    # inv_perm[old_block] is the position old_block moves to in the shuffle
    # below (entry 0 stays 0 for the null block).
    inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device)
    inv_perm[1:] = torch.argsort(perm) + 1
    kv_cache[:, 1:blocks_end, ...] = kv_cache[:, perm, ...]

    # Point each sequence's block table at the shuffled block locations.
    start_block_idx = 1
    for i, n_blocks in enumerate(blocks_per_seq):
        start = start_block_idx
        end = start + n_blocks
        block_table[i, :n_blocks] = inv_perm[start:end]
        start_block_idx = end

    # Slot mapping for the new (query) tokens, consistent with the block table.
    for i in range(batch_size):
        token_offsets = torch.arange(int(query_lens[i])) + int(context_lens[i])
        block_indices = token_offsets // block_size
        token_inter_block_offsets = token_offsets % block_size
        start = query_start_loc_cpu[i]
        end = query_start_loc_cpu[i + 1]
        slot_mapping[start:end] = (
            block_table[i, block_indices] * block_size
            + token_inter_block_offsets.to(device)
        )

    return kv_cache


class MockAttentionLayer:
    """A mock attention layer for testing.

    Carries only the per-layer q/k/v quantization scales, all set to 1.0, both
    as tensors on ``device`` and as plain Python floats.
    """

    def __init__(self, device: torch.device):
        self._q_scale = torch.tensor(1.0, device=device)
        self._k_scale = torch.tensor(1.0, device=device)
        self._v_scale = torch.tensor(1.0, device=device)
        # Float versions for flashinfer
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0


def run_attention_backend(
    backend: AttentionBackendEnum | str,
    kv_cache_spec: FullAttentionSpec,
    layer_names: list[str],
    vllm_config,
    device: torch.device,
    common_attn_metadata: CommonAttentionMetadata,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    sliding_window: int | None = None,
) -> torch.Tensor:
    """Run attention computation using the specified backend's AttentionImpl.

    Builds the backend's attention metadata from ``common_attn_metadata``,
    instantiates its ``AttentionImpl`` and runs a single forward pass.

    Args:
        backend: Backend to run. Besides ``AttentionBackendEnum`` members, the
            string ``"FLEX_ATTENTION_SLOW"`` runs ``FLEX_ATTENTION`` with the
            builder's ``direct_build`` disabled.
        kv_cache_spec: KV cache spec handed to the metadata builder.
        layer_names: Layer names handed to the metadata builder.
        vllm_config: Config supplying head counts and head size.
        device: Device handed to the metadata builder.
        common_attn_metadata: Backend-independent batch metadata.
        query, key, value: New-token inputs, already shaped by the caller.
        kv_cache: KV cache tensor, already in the layout ``backend`` expects.
        sliding_window: Optional sliding window passed to the implementation.

    Returns:
        The tensor returned by the implementation's ``forward``, which is
        given an output buffer allocated with ``torch.empty_like(query)``.
    """
    import contextlib

    # Handle special case for FLEX_ATTENTION_SLOW
    actual_backend = backend
    use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0")
    if backend == "FLEX_ATTENTION_SLOW":
        actual_backend = AttentionBackendEnum.FLEX_ATTENTION
        use_direct_block_mask = False

    builder_cls, impl_cls = try_get_attention_backend(actual_backend)

    if actual_backend == AttentionBackendEnum.FLASHINFER:
        # Mock flashinfer's get_per_layer_parameters while building metadata
        import unittest.mock

        from vllm.v1.attention.backends.utils import PerLayerParameters

        # Keep the mocked window in sync with the impl created below.
        # NOTE(review): assumes FlashInfer's impl uses sliding_window - 1 as its
        # left window and -1 for "no window" — confirm against FlashInferImpl.
        window_left = -1 if sliding_window is None else sliding_window - 1

        def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
            # Return the same mock parameters for every layer
            head_size = vllm_config.model_config.get_head_size()
            return {
                layer_name: PerLayerParameters(
                    window_left=window_left,
                    logits_soft_cap=0.0,  # No soft cap
                    sm_scale=1.0 / (head_size**0.5),  # Standard scale
                )
                for layer_name in layer_names
            }

        builder_context = unittest.mock.patch(
            "vllm.v1.attention.backends.flashinfer.get_per_layer_parameters",
            mock_get_per_layer_parameters,
        )
    else:
        builder_context = contextlib.nullcontext()

    # Build metadata
    with builder_context:
        builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
        if actual_backend == AttentionBackendEnum.FLEX_ATTENTION:
            builder.direct_build = use_direct_block_mask
        attn_metadata = builder.build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
        )

    # Instantiate implementation
    model_config = vllm_config.model_config
    parallel_config = vllm_config.parallel_config
    num_heads = model_config.get_num_attention_heads(parallel_config)
    num_kv_heads = model_config.get_num_kv_heads(parallel_config)
    head_size = model_config.get_head_size()
    scale = 1.0 / (head_size**0.5)
    impl = impl_cls(
        num_heads=num_heads,
        head_size=head_size,
        scale=scale,
        num_kv_heads=num_kv_heads,
        alibi_slopes=None,
        sliding_window=sliding_window,
        kv_cache_dtype="auto",
    )

    # Create mock layer and output buffer
    mock_layer = MockAttentionLayer(device)
    output = torch.empty_like(query)

    # Run forward pass
    # NOTE: The query, key, and value are already shaped correctly
    # in the calling test function.
    output = impl.forward(
        mock_layer, query, key, value, kv_cache, attn_metadata, output=output
    )

    return output


def _simulated_tp_hf_config_override(
    model: str, tensor_parallel_size: int
) -> dict[str, int] | None:
    """Build an HF config override that mimics one TP rank's head counts.

    Returns ``None`` when ``tensor_parallel_size <= 1``. Otherwise
    ``num_attention_heads`` is divided by ``tensor_parallel_size``. When the
    model config defines ``num_key_value_heads``, it is divided too, floored
    at 1.
    """
    if tensor_parallel_size <= 1:
        return None

    hf_text_config = ModelConfig(model=model, max_model_len=1).hf_text_config
    override = {
        "num_attention_heads": (
            hf_text_config.num_attention_heads // tensor_parallel_size
        ),
    }
    num_kv_heads = getattr(hf_text_config, "num_key_value_heads", None)
    if num_kv_heads is not None:
        override["num_key_value_heads"] = max(1, num_kv_heads // tensor_parallel_size)
    return override


def _test_backend_correctness(
    batch_spec: BatchSpec,
    model: str,
    backend_to_test: list[AttentionBackendEnum | str],
    mask_mod,
    *,
    block_size: int = 16,
    atol: float = 1e-2,
    rtol: float = 1e-2,
    tensor_parallel_size: int = 1,
):
    """
    Test that all backends produce similar outputs to a reference implementation.

    The reference is computed with ``torch.nn.attention.flex_attention`` on
    contiguous tensors, using a block mask built from ``mask_mod``.

    This test works by:
    1. Generating a batch of sequences with specified context and query lengths.
    2. Computing a ground-truth attention output with flex_attention on
       contiguous Q, K, and V tensors. The mask is ``mask_mod`` with the
       sequence's ``context_len`` bound as a keyword argument.
    3. Simulating vLLM's paged KV cache: the context portion of the K/V
       tensors is placed into a paged buffer according to a randomly
       permuted block table.
    4. Running each vLLM attention backend with the new queries and the
       simulated paged KV cache.
    5. Comparing each backend's output to the reference output with
       ``torch.testing.assert_close(rtol=rtol, atol=atol)``.

    Args:
        batch_spec: Per-request sequence and query lengths.
        model: Model name used to build the vLLM config.
        backend_to_test: Backends (or the "FLEX_ATTENTION_SLOW" pseudo-backend)
            to compare against the reference.
        mask_mod: flex_attention mask function called as
            ``mask_mod(b, h, q_idx, kv_idx, context_len=...)``.
        block_size: KV cache block size.
        atol: Absolute tolerance for the comparison.
        rtol: Relative tolerance for the comparison.
        tensor_parallel_size: Simulated tensor-parallel degree (see Note).

    Note: When tensor_parallel_size > 1, we simulate the head partitioning
    by overriding the model config to use fewer heads, without requiring
    multiple GPUs. This tests that backends work correctly with different
    head counts.
    """
    set_random_seed(42)

    hf_config_override = _simulated_tp_hf_config_override(model, tensor_parallel_size)

    vllm_config = create_vllm_config(
        model_name=model,
        tensor_parallel_size=1,  # Always use TP=1 to avoid multi-GPU requirements
        max_model_len=max(batch_spec.seq_lens),
        block_size=block_size,
        num_gpu_blocks=8192,
        hf_config_override=hf_config_override,
    )
    device = torch.device("cuda:0")

    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)

    # 1. Setup
    batch_size = batch_spec.batch_size
    seq_lens = batch_spec.seq_lens
    query_lens = batch_spec.query_lens
    model_config = vllm_config.model_config
    parallel_config = vllm_config.parallel_config
    num_q_heads = model_config.get_num_attention_heads(parallel_config)
    num_kv_heads = model_config.get_num_kv_heads(parallel_config)
    head_size = model_config.get_head_size()
    sliding_window = model_config.get_sliding_window()
    dtype = _convert_dtype_to_torch(model_config.dtype)
    block_size = vllm_config.cache_config.block_size
    scale = 1.0 / (head_size**0.5)

    # 2. Generate data and compute the reference output.
    # (The ``sdpa_*`` names are historical; the reference uses flex_attention.)
    all_q_vllm, all_k_vllm, all_v_vllm = [], [], []
    all_sdpa_outputs = []
    k_contexts, v_contexts = [], []

    for i in range(batch_size):
        s_len = seq_lens[i]
        q_len = query_lens[i]
        context_len = s_len - q_len

        # Generate Q, K, V for the whole sequence to be used in the reference
        q = torch.randn(q_len, num_q_heads, head_size, dtype=dtype, device=device)
        k_full = torch.randn(s_len, num_kv_heads, head_size, dtype=dtype, device=device)
        v_full = torch.randn(s_len, num_kv_heads, head_size, dtype=dtype, device=device)

        # The reference expects (N, H, L, D), so unsqueeze batch and permute
        q_sdpa_in = q.unsqueeze(0).transpose(1, 2)
        k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2)
        v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2)

        if num_q_heads != num_kv_heads:
            assert num_q_heads % num_kv_heads == 0, (
                f"num_q_heads ({num_q_heads}) must be divisible by "
                f"num_kv_heads ({num_kv_heads})"
            )
            repeats = num_q_heads // num_kv_heads
            k_sdpa_in = k_sdpa_in.repeat_interleave(repeats, dim=1)
            v_sdpa_in = v_sdpa_in.repeat_interleave(repeats, dim=1)

        # Query token i sits at absolute position context_len + i; mask_mod
        # decides which of the s_len KV positions it may attend to.
        final_mask_mod = partial(mask_mod, context_len=context_len)
        block_mask = create_block_mask(
            final_mask_mod, B=None, H=None, Q_LEN=q_len, KV_LEN=s_len, device=device
        )
        sdpa_out_i = flex_attention(
            q_sdpa_in,
            k_sdpa_in,
            v_sdpa_in,
            block_mask=block_mask,
            scale=scale,
            enable_gqa=True,
        )
        all_sdpa_outputs.append(sdpa_out_i.transpose(1, 2).squeeze(0))

        # Inputs for vLLM backends are just the new tokens
        all_q_vllm.append(q)
        all_k_vllm.append(k_full[context_len:])
        all_v_vllm.append(v_full[context_len:])

        # Contextual K/V data used to populate the paged cache
        k_contexts.append(k_full[:context_len])
        v_contexts.append(v_full[:context_len])

    query_vllm = torch.cat(all_q_vllm, dim=0)
    key_vllm = torch.cat(all_k_vllm, dim=0)
    value_vllm = torch.cat(all_v_vllm, dim=0)
    sdpa_output = torch.cat(all_sdpa_outputs, dim=0)

    common_attn_metadata = create_common_attn_metadata(
        batch_spec, vllm_config.cache_config.block_size, device
    )

    # 3. Simulate Paged KV Cache and a realistic slot_mapping
    kv_cache = create_and_prepopulate_kv_cache(
        k_contexts=k_contexts,
        v_contexts=v_contexts,
        block_size=block_size,
        num_kv_heads=num_kv_heads,
        head_size=head_size,
        dtype=dtype,
        device=device,
        num_blocks=vllm_config.cache_config.num_gpu_blocks or 1000,
        common_attn_metadata=common_attn_metadata,
        randomize_blocks=True,
    )

    def error_msg(msg: str, backend_name: str):
        return f"[{backend_name}] output differs from SDPA baseline. {msg}"

    # 4. Run vLLM backends and compare
    # Note: flex_attention has known Triton kernel compatibility issues
    # with test infrastructures
    for backend_name in backend_to_test:
        # Backends other than FlashInfer/Triton take the cache as created:
        #   [2, num_blocks, block_size, num_kv_heads, head_size]
        # FlashInfer + Triton:
        #   [num_blocks, 2, block_size, num_kv_heads, head_size]
        kv_cache_for_backend = kv_cache
        reset_kv_cache_layout = False
        if backend_name in (
            AttentionBackendEnum.FLASHINFER,
            AttentionBackendEnum.TRITON_ATTN,
        ):
            kv_cache_for_backend = kv_cache.transpose(0, 1)

        if backend_name == AttentionBackendEnum.FLASHINFER:
            # For FlashInfer, use the HND layout: make the cache physically
            # HND-contiguous while keeping its logical NHD shape.
            kv_cache_for_backend = (
                kv_cache_for_backend.transpose(2, 3).contiguous().transpose(2, 3)
            )
            set_kv_cache_layout("HND")
            reset_kv_cache_layout = True
        elif backend_name == AttentionBackendEnum.TRITON_ATTN:
            kv_cache_for_backend = kv_cache_for_backend.contiguous()

        try:
            backend_output = run_attention_backend(
                backend_name,
                kv_cache_spec,
                ["placeholder"],
                vllm_config,
                device,
                common_attn_metadata,
                query_vllm,
                key_vllm,
                value_vllm,
                kv_cache_for_backend,
                sliding_window=sliding_window,
            )
        finally:
            if reset_kv_cache_layout:
                set_kv_cache_layout(None)

        # Check shape and dtype consistency
        assert backend_output.shape == sdpa_output.shape, (
            f"[{backend_name}] shape {backend_output.shape} != "
            f"SDPA shape {sdpa_output.shape}"
        )
        assert backend_output.dtype == sdpa_output.dtype, (
            f"[{backend_name}] dtype {backend_output.dtype} != "
            f"SDPA dtype {sdpa_output.dtype}"
        )

        assert torch.isfinite(backend_output).all(), (
            f"[{backend_name}] produced non-finite values"
        )

        # Check numerical similarity
        torch.testing.assert_close(
            backend_output,
            sdpa_output,
            rtol=rtol,
            atol=atol,
            msg=partial(error_msg, backend_name=backend_name),
        )


@pytest.mark.parametrize(
    "batch_spec_name",
    [
        "small_decode",
        "small_prefill",
        "mixed_small",
        "medium_decode",
        "medium_prefill",
        "mixed_medium",
        "large_decode",
        "large_prefill",
        "single_decode",
        "single_prefill",
    ],
)
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
def test_causal_backend_correctness(
    batch_spec_name: str, model: str, tensor_parallel_size: int
):
    """Test backend's correctness with causal attention."""

    def causal_mask_mod(
        b: torch.Tensor,
        h: torch.Tensor,
        q_idx: torch.Tensor,
        kv_idx: torch.Tensor,
        *,
        context_len: int,
    ):
        # Query token q_idx sits at absolute position context_len + q_idx
        return (q_idx + context_len) >= kv_idx

    batch_spec = BATCH_SPECS[batch_spec_name]
    LARGE_BLOCK_BACKENDS = (
        [AttentionBackendEnum.FLEX_ATTENTION]
        if is_torch_equal_or_newer("2.9.0.dev0")
        else []
    )

    SMALL_BLOCK_BACKENDS = [
        x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
    ]
    if current_platform.is_rocm():
        # FLASH_ATTN is not supported on ROCm
        SMALL_BLOCK_BACKENDS = [
            x for x in SMALL_BLOCK_BACKENDS if x != AttentionBackendEnum.FLASH_ATTN
        ]

    _test_backend_correctness(
        batch_spec,
        model,
        SMALL_BLOCK_BACKENDS,
        causal_mask_mod,
        tensor_parallel_size=tensor_parallel_size,
    )

    # Fast FlexAttention needs to run with block_size=128
    if LARGE_BLOCK_BACKENDS:
        _test_backend_correctness(
            batch_spec,
            model,
            LARGE_BLOCK_BACKENDS,
            causal_mask_mod,
            block_size=128,
            tensor_parallel_size=tensor_parallel_size,
        )


# Backends exercised by the sliding-window correctness test.
SLIDING_WINDOW_BACKENDS_TO_TEST = [
    AttentionBackendEnum.FLASH_ATTN,
    AttentionBackendEnum.FLEX_ATTENTION,
    AttentionBackendEnum.TRITON_ATTN,
    "FLEX_ATTENTION_SLOW",
]
if current_platform.is_rocm():
    # FLASH_ATTN is not supported on ROCm
    SLIDING_WINDOW_BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASH_ATTN)


@pytest.mark.parametrize(
    "batch_spec_name",
    [
        "small_decode",
        "small_prefill",
        "mixed_medium",
        "large_decode",
        "large_prefill",
        "mixed_large",
    ],
)
@pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
def test_sliding_window_backend_correctness(
    batch_spec_name: str, model: str, tensor_parallel_size: int
):
    """Test backend's correctness with sliding window attention."""

    def sliding_window_mask_mod(
        b: torch.Tensor,
        h: torch.Tensor,
        q_idx: torch.Tensor,
        kv_idx: torch.Tensor,
        *,
        context_len: int,
        sliding_window: int,
    ):
        # Causal, and restricted to the last `sliding_window` positions
        # (the query's own position included).
        causal_mask = q_idx + context_len >= kv_idx
        window_mask = q_idx + context_len - kv_idx < sliding_window
        return causal_mask & window_mask

    batch_spec = BATCH_SPECS[batch_spec_name]
    model_config = ModelConfig(model=model, max_model_len=max(batch_spec.seq_lens))
    sliding_window = model_config.get_sliding_window()
    # Fail clearly instead of with a TypeError inside the mask function.
    assert sliding_window is not None, (
        f"{model} has no sliding window; this test requires one"
    )
    sliding_window_mask_mod_fn = partial(
        sliding_window_mask_mod, sliding_window=sliding_window
    )

    LARGE_BLOCK_BACKENDS = (
        [AttentionBackendEnum.FLEX_ATTENTION]
        if is_torch_equal_or_newer("2.9.0.dev0")
        else []
    )
    SMALL_BLOCK_BACKENDS = [
        x for x in SLIDING_WINDOW_BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
    ]
    _test_backend_correctness(
        batch_spec,
        model,
        SMALL_BLOCK_BACKENDS,
        sliding_window_mask_mod_fn,
        tensor_parallel_size=tensor_parallel_size,
    )

    # Fast FlexAttention needs to run with block_size=128
    if LARGE_BLOCK_BACKENDS:
        _test_backend_correctness(
            batch_spec,
            model,
            LARGE_BLOCK_BACKENDS,
            sliding_window_mask_mod_fn,
            block_size=128,
            tensor_parallel_size=tensor_parallel_size,
        )