# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for v1 attention backends without GPUModelRunner dependency."""

import pytest
import torch

from tests.v1.attention.utils import (BatchSpec, _Backend,
                                      create_common_attn_metadata,
                                      create_standard_kv_cache_spec,
                                      create_vllm_config,
                                      get_attention_backend)
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
                                              set_kv_cache_layout)
from vllm.v1.kv_cache_interface import FullAttentionSpec

# Attention backends exercised by the correctness test in this module.
BACKENDS_TO_TEST = [
    _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1,
    _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN
]

# Remove flashinfer from the list if it's not available
try:
    import flashinfer  # noqa: F401
except ImportError:
    BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_VLLM_V1)


def _convert_dtype_to_torch(dtype):
    """Convert ModelDType to torch.dtype."""
    if isinstance(dtype, str):
        if dtype == "auto":
            return torch.float16  # Default dtype for testing
        elif dtype in STR_DTYPE_TO_TORCH_DTYPE:
            return STR_DTYPE_TO_TORCH_DTYPE[dtype]
        else:
            raise ValueError(f"Unknown dtype: {dtype}")
    elif isinstance(dtype, torch.dtype):
        return dtype
    else:
        raise ValueError(f"Unknown dtype: {dtype}")


# Named batch shapes used to parametrize the tests. Each entry pairs the
# total length of every sequence with the number of new query tokens for it.
BATCH_SPECS = dict(
    small_decode=BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]),
    small_prefill=BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]),
    mixed_small=BatchSpec(seq_lens=[32, 40, 48, 56],
                          query_lens=[1, 1, 5, 5]),
    medium_decode=BatchSpec(seq_lens=[128, 256, 512, 1024] * 2,
                            query_lens=[1] * 8),
    medium_prefill=BatchSpec(seq_lens=[256, 512, 1024, 2048],
                             query_lens=[16] * 4),
    mixed_medium=BatchSpec(seq_lens=[512, 1024, 2048] * 2,
                           query_lens=[1] * 3 + [7] * 3),
    large_decode=BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32),
    large_prefill=BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
    single_decode=BatchSpec(seq_lens=[1024], query_lens=[1]),
    single_prefill=BatchSpec(seq_lens=[1024], query_lens=[64]),
)


def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec,
                          device: torch.device,
                          num_blocks: int = 100) -> torch.Tensor:
    """Build a KV cache tensor filled with random values for testing.

    Args:
        kv_cache_spec: Supplies ``block_size``, ``num_kv_heads``,
            ``head_size`` and ``dtype`` for the cache.
        device: Device to allocate the tensor on.
        num_blocks: Number of blocks in the cache.

    Returns:
        A ``torch.randn`` tensor of shape
        ``(2, num_blocks, block_size, num_kv_heads, head_size)``.
    """
    cache_shape = (
        2,  # K and V
        num_blocks,
        kv_cache_spec.block_size,
        kv_cache_spec.num_kv_heads,
        kv_cache_spec.head_size,
    )
    cache_dtype = _convert_dtype_to_torch(kv_cache_spec.dtype)
    return torch.randn(cache_shape, dtype=cache_dtype, device=device)


def create_and_prepopulate_kv_cache(
        k_contexts: list[torch.Tensor],
        v_contexts: list[torch.Tensor],
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
        num_blocks: int,
        common_attn_metadata: CommonAttentionMetadata,
        randomize_blocks: bool = True) -> torch.Tensor:
    """Create a paged KV cache and prepopulate it with context data.

    Each sequence's context tokens are written into consecutive blocks
    starting at block 1 (block 0 is treated as the null block). The occupied
    blocks are then optionally shuffled, and the block table and slot mapping
    held by ``common_attn_metadata`` are overwritten in place so that they
    describe the final block placement.

    Args:
        k_contexts: List of key context tensors for each sequence
        v_contexts: List of value context tensors for each sequence
        block_size: Size of each block
        num_kv_heads: Number of KV heads
        head_size: Size of each head
        dtype: Data type for the cache
        device: Device to create the cache on
        num_blocks: Total number of blocks in the cache
        common_attn_metadata: Batch metadata. Sequence lengths, query start
            locations and computed-token counts are read from it; its
            ``block_table_tensor`` and ``slot_mapping`` are written in place.
        randomize_blocks: Whether to randomly permute blocks
                          or use sequential order

    Returns:
        The KV cache tensor of shape
        ``(2, num_blocks, block_size, num_kv_heads, head_size)``.

    Raises:
        ValueError: If the batch needs more blocks than ``num_blocks``.
    """
    batch_size = len(k_contexts)
    seq_lens = common_attn_metadata.seq_lens_cpu
    query_lens = common_attn_metadata.query_start_loc_cpu[
        1:] - common_attn_metadata.query_start_loc_cpu[:-1]
    context_lens = common_attn_metadata.num_computed_tokens_cpu
    block_table = common_attn_metadata.block_table_tensor
    slot_mapping = common_attn_metadata.slot_mapping

    # Blocks needed per sequence: stay block aligned and leave room for the
    # new (query) tokens as well as the context.
    blocks_per_seq = [
        cdiv(int(seq_lens[i]), block_size) for i in range(batch_size)
    ]
    # Block ids start at 1 since block_id=0 is considered the null block
    blocks_end = 1 + sum(blocks_per_seq)
    if blocks_end > num_blocks:
        # Fail early with a readable message instead of an out-of-range
        # index on the device further down.
        raise ValueError(
            f"Batch needs {blocks_end} blocks (including the null block) "
            f"but the cache only has {num_blocks}")

    # Create KV cache
    kv_cache = torch.empty(2,
                           num_blocks,
                           block_size,
                           num_kv_heads,
                           head_size,
                           dtype=dtype,
                           device=device)
    # Token-major view sharing storage with kv_cache
    kv_cache_flat = kv_cache.view(2, -1, num_kv_heads, head_size)

    # Populate the cache with the context tokens
    start_block_idx = 1
    for i in range(batch_size):
        k_context, v_context = k_contexts[i], v_contexts[i]
        start = start_block_idx * block_size
        end = start + k_context.shape[0]
        kv_cache_flat[0, start:end, ...] = k_context
        kv_cache_flat[1, start:end, ...] = v_context
        start_block_idx += blocks_per_seq[i]

    # Permute the context blocks (excluding block 0 which is null)
    if randomize_blocks:
        # Random permutation starting from block 1
        perm = torch.randperm(blocks_end - 1) + 1
    else:
        # Sequential order starting from block 1
        perm = torch.arange(1, blocks_end)

    # inv_perm[old_block] is the block id the data ends up in after the
    # shuffle below; entry 0 (null block) stays 0.
    inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device)
    # Add 1 to account for starting from block 1
    inv_perm[1:] = torch.argsort(perm) + 1
    kv_cache[:, 1:blocks_end, ...] = kv_cache[:, perm, ...]

    # Construct the right block table
    start_block_idx = 1
    for i in range(batch_size):
        num_blocks_for_seq = blocks_per_seq[i]
        start = start_block_idx
        end = start + num_blocks_for_seq
        block_table[i, :num_blocks_for_seq] = inv_perm[start:end]
        start_block_idx += num_blocks_for_seq

    # Create a realistic slot mapping that corresponds to the block table
    for i in range(batch_size):
        # Positions of the new tokens within their sequence
        token_offsets = torch.arange(int(query_lens[i])) + int(context_lens[i])
        block_indices = token_offsets // block_size
        token_inter_block_offsets = token_offsets % block_size
        start = common_attn_metadata.query_start_loc_cpu[i]
        end = common_attn_metadata.query_start_loc_cpu[i + 1]
        slot_mapping[start:end] = block_table[
            i,
            block_indices] * block_size + token_inter_block_offsets.to(device)

    return kv_cache


class MockAttentionLayer:
    """Stand-in attention layer for tests that only carries scale attributes.

    Every scale is fixed at 1.0: ``_q_scale``, ``_k_scale`` and ``_v_scale``
    are tensors on the given device, while ``_k_scale_float`` and
    ``_v_scale_float`` are plain Python floats.
    """

    def __init__(self, device: torch.device):
        # One separate tensor per scale, all placed on ``device``.
        self._q_scale, self._k_scale, self._v_scale = (
            torch.tensor(1.0, device=device) for _ in range(3))
        # Add float versions for flashinfer
        self._k_scale_float = self._v_scale_float = 1.0


def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
                          layer_names: list[str], vllm_config,
                          device: torch.device,
                          common_attn_metadata: CommonAttentionMetadata,
                          query: torch.Tensor, key: torch.Tensor,
                          value: torch.Tensor,
                          kv_cache: torch.Tensor) -> torch.Tensor:
    """Run attention computation using the specified backend's AttentionImpl.

    The backend's metadata builder turns ``common_attn_metadata`` into
    backend-specific metadata, an implementation is instantiated from the
    head counts and head size in ``vllm_config``, and its ``forward`` is
    called with a ``MockAttentionLayer``.

    Args:
        backend: Backend whose builder and implementation classes are used.
        kv_cache_spec: KV cache spec passed to the metadata builder.
        layer_names: Layer names passed to the metadata builder.
        vllm_config: Config passed to the builder and queried for head
            counts and head size.
        device: Device for the builder and the mock layer.
        common_attn_metadata: Backend-agnostic batch metadata.
        query: Query tensor passed to ``forward``.
        key: Key tensor passed to ``forward``.
        value: Value tensor passed to ``forward``.
        kv_cache: KV cache passed to ``forward``, already in the layout the
            backend expects.

    Returns:
        The value returned by the implementation's ``forward`` call.
    """

    builder_cls, impl_cls = get_attention_backend(backend)

    def build_attn_metadata():
        # Single definition of the builder call, used by both branches below.
        builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
        return builder.build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
        )

    # Mock flashinfer's get_per_layer_parameters if needed
    if backend == _Backend.FLASHINFER_VLLM_V1:
        import unittest.mock

        from vllm.v1.attention.backends.utils import PerLayerParameters

        def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
            # Return the same mock parameters for every requested layer
            head_size = vllm_config.model_config.get_head_size()
            return {
                layer_name:
                PerLayerParameters(
                    window_left=-1,  # No sliding window
                    logits_soft_cap=0.0,  # No soft cap
                    sm_scale=1.0 / (head_size**0.5)  # Standard scale
                )
                for layer_name in layer_names
            }

        # The patch only needs to be active while the metadata is built.
        with unittest.mock.patch(
                'vllm.v1.attention.backends.flashinfer.get_per_layer_parameters',
                mock_get_per_layer_parameters):
            attn_metadata = build_attn_metadata()
    else:
        attn_metadata = build_attn_metadata()

    # Instantiate implementation
    num_heads = vllm_config.model_config.get_num_attention_heads(
        vllm_config.parallel_config)
    num_kv_heads = vllm_config.model_config.get_num_kv_heads(
        vllm_config.parallel_config)
    head_size = vllm_config.model_config.get_head_size()
    scale = 1.0 / (head_size**0.5)
    impl = impl_cls(
        num_heads=num_heads,
        head_size=head_size,
        scale=scale,
        num_kv_heads=num_kv_heads,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype="auto",
    )

    # Create mock layer and output buffer
    mock_layer = MockAttentionLayer(device)
    output = torch.empty_like(query)

    # Run forward pass
    # NOTE: The query, key, and value are already shaped correctly
    # in the calling test function.
    output = impl.forward(mock_layer,
                          query,
                          key,
                          value,
                          kv_cache,
                          attn_metadata,
                          output=output)

    return output


@pytest.mark.parametrize("batch_spec_name", [
    "small_decode", "small_prefill", "mixed_small", "medium_decode",
    "medium_prefill", "mixed_medium", "large_decode", "large_prefill",
    "single_decode", "single_prefill"
])
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
def test_backend_correctness(batch_spec_name: str, model: str):
    """
    Test that all backends produce similar outputs to a reference implementation
    using torch.nn.functional.scaled_dot_product_attention.

    This test works by:
    1. Generating a batch of sequences with specified context and query lengths.
    2. Computing a ground-truth attention output using torch.sdpa on
       contiguous Q, K, and V tensors.
    3. Simulating vLLM's paged KV cache: It takes the context portion of the
       K/V tensors and manually places them into a paged buffer according to
       the test's (randomly generated) block table.
    4. Running each vLLM attention backend with the new queries and the
       simulated paged KV cache.
    5. Comparing the vLLM backend's output to the ground-truth SDPA output.
    """
    batch_spec = BATCH_SPECS[batch_spec_name]
    vllm_config = create_vllm_config(model_name=model,
                                     max_model_len=max(batch_spec.seq_lens),
                                     num_gpu_blocks=8192)
    device = torch.device("cuda:0")

    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)

    # 1. Setup
    batch_size = batch_spec.batch_size
    seq_lens = batch_spec.seq_lens
    query_lens = batch_spec.query_lens
    num_q_heads = vllm_config.model_config.get_num_attention_heads(
        vllm_config.parallel_config)
    num_kv_heads = vllm_config.model_config.get_num_kv_heads(
        vllm_config.parallel_config)
    head_size = vllm_config.model_config.get_head_size()
    dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
    block_size = vllm_config.cache_config.block_size
    scale = 1.0 / (head_size**0.5)

    # 2. Generate data and compute SDPA reference output
    all_q_vllm, all_k_vllm, all_v_vllm = [], [], []
    all_sdpa_outputs = []
    k_contexts, v_contexts = [], []

    for i in range(batch_size):
        s_len = seq_lens[i]
        q_len = query_lens[i]
        context_len = s_len - q_len

        # Generate Q, K, V for the whole sequence to be used in SDPA
        q = torch.randn(q_len,
                        num_q_heads,
                        head_size,
                        dtype=dtype,
                        device=device)
        k_full = torch.randn(s_len,
                             num_kv_heads,
                             head_size,
                             dtype=dtype,
                             device=device)
        v_full = torch.randn(s_len,
                             num_kv_heads,
                             head_size,
                             dtype=dtype,
                             device=device)

        # SDPA expects (N, H, L, D), so unsqueeze batch and permute
        q_sdpa_in = q.unsqueeze(0).transpose(1, 2)
        k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2)
        v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2)

        if num_q_heads != num_kv_heads:
            assert num_q_heads % num_kv_heads == 0, (
                f"num_q_heads ({num_q_heads}) must be divisible by "
                f"num_kv_heads ({num_kv_heads})")
            repeats = num_q_heads // num_kv_heads
            k_sdpa_in = k_sdpa_in.repeat_interleave(repeats, dim=1)
            v_sdpa_in = v_sdpa_in.repeat_interleave(repeats, dim=1)

        # Create causal mask: query token q_idx attends to positions 0 to
        # (context_len + q_idx). Built without a Python loop so the outer
        # loop variable is not shadowed.
        q_pos = torch.arange(q_len, device=device).unsqueeze(1) + context_len
        kv_pos = torch.arange(s_len, device=device).unsqueeze(0)
        attn_mask = torch.zeros((q_len, s_len), device=device, dtype=dtype)
        attn_mask.masked_fill_(kv_pos > q_pos, float('-inf'))

        sdpa_out_i = torch.nn.functional.scaled_dot_product_attention(
            q_sdpa_in,
            k_sdpa_in,
            v_sdpa_in,
            attn_mask=attn_mask,
            scale=scale,
            enable_gqa=True)
        # Convert back to (L, H, D)
        all_sdpa_outputs.append(sdpa_out_i.transpose(1, 2).squeeze(0))

        # Inputs for vLLM backends are just the new tokens
        all_q_vllm.append(q)
        all_k_vllm.append(k_full[context_len:])
        all_v_vllm.append(v_full[context_len:])

        # Contextual K/V data used to populate the paged cache
        k_contexts.append(k_full[:context_len])
        v_contexts.append(v_full[:context_len])

    query_vllm = torch.cat(all_q_vllm, dim=0)
    key_vllm = torch.cat(all_k_vllm, dim=0)
    value_vllm = torch.cat(all_v_vllm, dim=0)
    sdpa_output = torch.cat(all_sdpa_outputs, dim=0)

    common_attn_metadata = create_common_attn_metadata(
        batch_spec, vllm_config.cache_config.block_size, device)

    # 3. Simulate Paged KV Cache and a realistic slot_mapping
    kv_cache = create_and_prepopulate_kv_cache(
        k_contexts=k_contexts,
        v_contexts=v_contexts,
        block_size=block_size,
        num_kv_heads=num_kv_heads,
        head_size=head_size,
        dtype=dtype,
        device=device,
        num_blocks=vllm_config.cache_config.num_gpu_blocks or 1000,
        common_attn_metadata=common_attn_metadata,
        randomize_blocks=True)

    # 4. Run vLLM backends and compare
    # Note: flex_attention has known Triton kernel compatibility issues
    # with test infrastructures
    for backend_name in BACKENDS_TO_TEST:
        # FlashAttention + FlexAttention:
        #   [2, num_blocks, block_size, num_kv_heads, head_size]
        # FlashInfer:
        #   [num_blocks, 2, block_size, num_kv_heads, head_size]
        # Select the appropriate KV cache format for each backend
        kv_cache_for_backend = kv_cache
        if backend_name == _Backend.FLASHINFER_VLLM_V1:
            kv_cache_for_backend = kv_cache.transpose(0, 1)

            # For FlashInfer, default to the HND layout: make the memory
            # contiguous with dims 2 and 3 swapped, then swap them back so
            # the logical shape is unchanged.
            kv_cache_for_backend = kv_cache_for_backend.transpose(
                2, 3).contiguous().transpose(2, 3)
            # NOTE(review): the layout is never reset inside this loop, so
            # it presumably stays in effect for the backends that run after
            # FlashInfer — confirm this is intended.
            set_kv_cache_layout("HND")

        backend_output = run_attention_backend(backend_name, kv_cache_spec,
                                               ["placeholder"], vllm_config,
                                               device, common_attn_metadata,
                                               query_vllm, key_vllm,
                                               value_vllm,
                                               kv_cache_for_backend)

        # Check shape and dtype consistency
        assert backend_output.shape == sdpa_output.shape, (
            f"[{backend_name}] shape {backend_output.shape} != "
            f"SDPA shape {sdpa_output.shape}")
        assert backend_output.dtype == sdpa_output.dtype, (
            f"[{backend_name}] dtype {backend_output.dtype} != "
            f"SDPA dtype {sdpa_output.dtype}")

        assert torch.isfinite(backend_output).all(), (
            f"[{backend_name}] produced non-finite values")

        # Check numerical similarity
        rtol = 1e-2
        atol = 5e-3

        if backend_name == _Backend.FLEX_ATTENTION:
            atol = 5e-1  # TODO: figure out why flex_attention has such large
            # numerical differences for medium_decode, medium_prefill,
            # mixed_medium

        # The two diff values are only used in the failure message below.
        max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
        max_rel_diff = torch.max(
            torch.abs(backend_output - sdpa_output) /
            torch.abs(sdpa_output)).item()
        all_close = torch.allclose(backend_output,
                                   sdpa_output,
                                   rtol=rtol,
                                   atol=atol)

        assert all_close, (
            f"[{backend_name}] output differs from SDPA baseline. "
            f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})")