test_mla_backends.py 42.6 KB
Newer Older
Matthew Bonanni's avatar
Matthew Bonanni committed
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
"""Tests for v1 MLA backends without GPUModelRunner dependency.

Known Issues:
- FLASH_ATTN_MLA backend occasionally produces NaN values in
  test_backend_correctness[mixed_small] when run after
  test_backend_correctness[small_prefill], but passes when run alone.
"""
10

Matthew Bonanni's avatar
Matthew Bonanni committed
11
12
13
import pytest
import torch

14
15
16
17
from tests.v1.attention.utils import (
    BatchSpec,
    create_common_attn_metadata,
    create_vllm_config,
18
    try_get_attention_backend,
19
)
20
from vllm import _custom_ops as ops
21
from vllm.config.vllm import set_current_vllm_config
22
23
24
25
from vllm.model_executor.layers.attention.mla_attention import (
    QueryLenSupport,
    _DecodeConcatQuantFP8,
)
26
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
27
28
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
29
from vllm.utils.math_utils import cdiv
30
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
31
from vllm.v1.attention.backend import CommonAttentionMetadata
32
33
34
from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
35
from vllm.v1.kv_cache_interface import MLAAttentionSpec
Matthew Bonanni's avatar
Matthew Bonanni committed
36
37

# Candidate MLA backends. Entries the current environment cannot run are
# removed by the checks below, so this list is final once the module is
# imported.
BACKENDS_TO_TEST = [
    AttentionBackendEnum.CUTLASS_MLA,
    AttentionBackendEnum.FLASHMLA,
    AttentionBackendEnum.FLASH_ATTN_MLA,
    AttentionBackendEnum.FLASHINFER_MLA,
    AttentionBackendEnum.TRITON_MLA,
]

DEVICE_TYPE = current_platform.device_type

# Remove sm100 backends from the list if not using sm100
if not torch.cuda.is_available() or torch.cuda.get_device_properties(0).major < 10:
    BACKENDS_TO_TEST.remove(AttentionBackendEnum.CUTLASS_MLA)
    BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHINFER_MLA)

# Remove FLASH_ATTN_MLA from the list if not supported
if not flash_attn_supports_mla():
    BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASH_ATTN_MLA)

# Remove FLASHMLA from the list if not supported
# (only the first element of the returned value is used as the flag)
if not is_flashmla_dense_supported()[0]:
    BACKENDS_TO_TEST.remove(AttentionBackendEnum.FLASHMLA)
59

60

61
62
63
64
65
66
67
68
69
70
71
# Backends whose metadata builder accepts more than a single query token per
# decode request. A builder without a ``query_len_support`` attribute is
# treated as SINGLE_ONLY.
SPEC_DECODE_BACKENDS = []
for backend in BACKENDS_TO_TEST:
    builder_cls, _ = try_get_attention_backend(backend)
    query_len_support = getattr(
        builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY
    )
    if query_len_support != QueryLenSupport.SINGLE_ONLY:
        SPEC_DECODE_BACKENDS.append(backend)

# KV-cache block size to use for each backend: the first supported kernel
# block size the backend reports, or 16 when it reports none.
BACKEND_BLOCK_SIZES = {}
for backend in BACKENDS_TO_TEST:
    supported_sizes = backend.get_class().get_supported_kernel_block_sizes()
    if supported_sizes:
        default_size = supported_sizes[0]
        # An entry is either a plain int or an object exposing ``.base``.
        block_size = (
            default_size if isinstance(default_size, int) else default_size.base
        )
    else:
        block_size = 16
    BACKEND_BLOCK_SIZES[backend] = block_size

Matthew Bonanni's avatar
Matthew Bonanni committed
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Seed torch's global RNG when the module is imported so that randomly
# generated values (e.g. torch.randperm in the cache helper) are reproducible.
torch.manual_seed(42)


def _convert_dtype_to_torch(dtype):
    """Convert ModelDType to torch.dtype."""
    if isinstance(dtype, str):
        if dtype == "auto":
            return torch.float16  # Default dtype for testing
        elif dtype in STR_DTYPE_TO_TORCH_DTYPE:
            return STR_DTYPE_TO_TORCH_DTYPE[dtype]
        else:
            raise ValueError(f"Unknown dtype: {dtype}")
    elif isinstance(dtype, torch.dtype):
        return dtype
    else:
        raise ValueError(f"Unknown dtype: {dtype}")


# Define common batch configurations
# Named batch shapes used to parametrize the tests. For each sequence,
# ``seq_lens`` is the total length and ``query_lens`` the number of new query
# tokens; the cache helper below derives the context length as the difference.
BATCH_SPECS = {
    "small_decode": BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]),
    "small_prefill": BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]),
    "mixed_small": BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]),
    "medium_decode": BatchSpec(
        seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024],
        query_lens=[1, 1, 1, 1, 1, 1, 1, 1],
    ),
    "medium_prefill": BatchSpec(
        seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]
    ),
    "mixed_medium": BatchSpec(
        seq_lens=[512, 1024, 2048, 512, 1024, 2048], query_lens=[1, 1, 1, 7, 7, 7]
    ),
    "large_decode": BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32),
    "large_prefill": BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8),
    "single_decode": BatchSpec(seq_lens=[1024], query_lens=[1]),
    "single_prefill": BatchSpec(seq_lens=[1024], query_lens=[64]),
    "spec_decode_small": BatchSpec(
        seq_lens=[128, 256, 512, 1024], query_lens=[4, 4, 4, 4]
    ),
    "spec_decode_medium": BatchSpec(
        seq_lens=[512, 1024, 2048, 512, 1024, 2048], query_lens=[8, 8, 8, 8, 8, 8]
    ),
}


def create_and_prepopulate_kv_cache(
    kv_c_contexts: list[torch.Tensor],
    k_pe_contexts: list[torch.Tensor],
    block_size: int,
    head_size: int,
    dtype: torch.dtype,
    device: torch.device,
    num_blocks: int,
    common_attn_metadata: CommonAttentionMetadata,
    randomize_blocks: bool = True,
    kv_cache_dtype: str | None = None,
    scale: float | torch.Tensor = 1.0,
) -> torch.Tensor:
    """Create and prepopulate an MLA KV cache with context data.

    Besides building the cache, this overwrites
    ``common_attn_metadata.block_table_tensor`` and
    ``common_attn_metadata.slot_mapping`` in place so that they describe the
    (optionally shuffled) block layout of the returned cache.

    Args:
        kv_c_contexts: List of latent KV context tensors for each sequence
        k_pe_contexts: List of key positional embedding context tensors
                       for each sequence
        block_size: Size of each block
        head_size: Size of the last dimension of a cache entry. Ignored for
                   the "fp8_ds_mla" layout, whose entry size is derived from
                   the context tensors.
        dtype: Data type for the cache (non-fp8 layouts only; fp8 layouts
               are stored as uint8)
        device: Device to create the cache on
        num_blocks: Total number of blocks in the cache
        common_attn_metadata: Common attention metadata
        randomize_blocks: Whether to randomly permute blocks
                          or use sequential order
        kv_cache_dtype: Optional kv cache dtype string. For fp8 cache dtype,
                        the cache is populated via concat_and_cache_mla.
        scale: Scaling factor forwarded to concat_and_cache_mla when the
               fp8 cache layout is requested.

    Returns:
        MLA KV cache tensor

    Raises:
        ValueError: If ``num_blocks`` cannot hold the null block plus the
            blocks needed by every sequence.
    """
    batch_size = len(kv_c_contexts)
    seq_lens = common_attn_metadata.seq_lens.cpu()
    query_lens = (
        common_attn_metadata.query_start_loc_cpu[1:]
        - common_attn_metadata.query_start_loc_cpu[:-1]
    )
    context_lens = seq_lens - query_lens
    block_table = common_attn_metadata.block_table_tensor
    slot_mapping = common_attn_metadata.slot_mapping

    # Each sequence owns enough whole blocks for its full length (context plus
    # new tokens). Computed once and reused by both passes below.
    blocks_per_seq = [cdiv(int(seq_lens[i]), block_size) for i in range(batch_size)]

    # Fail early with a clear message instead of an indexing error (or an
    # out-of-range cache write) further down. The +1 is the null block 0.
    required_blocks = 1 + sum(blocks_per_seq)
    if required_blocks > num_blocks:
        raise ValueError(
            f"num_blocks={num_blocks} is too small: {required_blocks} blocks "
            f"are required (including the null block)"
        )

    fp8_attention = bool(kv_cache_dtype) and kv_cache_dtype.startswith("fp8")
    use_fp8_ds_mla = kv_cache_dtype == "fp8_ds_mla"

    if fp8_attention:
        if use_fp8_ds_mla:
            kv_lora_rank = kv_c_contexts[0].shape[-1]
            rope_dim = k_pe_contexts[0].shape[-1]
            # 4 * 4: 4 float32 scale values for 128-element tiles
            # 2 * rope_dim: 16-bit RoPE values
            kv_entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim
        else:
            kv_entry_size = head_size

        # fp8 layouts are held as raw bytes.
        kv_cache = torch.zeros(
            num_blocks, block_size, kv_entry_size, dtype=torch.uint8, device=device
        )
        scale_tensor = (
            scale
            if isinstance(scale, torch.Tensor)
            else torch.tensor(scale, dtype=torch.float32, device=device)
        )
        scale_tensor = scale_tensor.to(device=device, dtype=torch.float32)
    else:
        # Create MLA KV cache: (num_blocks, block_size, head_size)
        kv_cache = torch.zeros(
            num_blocks, block_size, head_size, dtype=dtype, device=device
        )
        kv_cache_flat = kv_cache.view(-1, head_size)

    # Populate the cache with the context tokens
    # Start from block_id=1 since block_id=0 is considered the null block
    start_block_idx = 1
    for i in range(batch_size):
        kv_c_context, k_pe_context = kv_c_contexts[i], k_pe_contexts[i]
        context_len = kv_c_context.shape[0]

        # A sequence without context has nothing to write, but its blocks
        # are still reserved below.
        if context_len > 0:
            start = start_block_idx * block_size

            if fp8_attention:
                slots = (
                    torch.arange(context_len, device=device, dtype=torch.long)
                    + start
                )
                ops.concat_and_cache_mla(
                    kv_c_context,
                    k_pe_context.squeeze(1),
                    kv_cache,
                    slots,
                    kv_cache_dtype=kv_cache_dtype,
                    scale=scale_tensor,
                )
            else:
                kv_context = torch.cat(
                    [kv_c_context, k_pe_context.squeeze(1)], dim=-1
                )
                end = start + kv_context.shape[0]
                kv_cache_flat[start:end, ...] = kv_context

        # Stay block aligned and allocate enough blocks for the new tokens
        start_block_idx += blocks_per_seq[i]

    blocks_end = start_block_idx

    # Permute the context blocks (excluding block 0 which is null)
    if randomize_blocks:
        # Random permutation starting from block 1
        perm = torch.randperm(blocks_end - 1) + 1
    else:
        # Sequential order starting from block 1
        perm = torch.arange(1, blocks_end)

    # inv_perm[old_block] is the block that old_block's data ends up in after
    # the shuffle below; entry 0 stays 0 (the null block).
    inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device)
    inv_perm[1:] = torch.argsort(perm) + 1  # Add 1 to account for starting from block 1
    kv_cache[1:blocks_end, ...] = kv_cache[perm, ...]

    # Construct the right block table
    # Start from block_id=1 since block_id=0 is considered the null block
    start_block_idx = 1
    for i in range(batch_size):
        num_blocks_for_seq = blocks_per_seq[i]
        start = start_block_idx
        end = start + num_blocks_for_seq
        block_table[i, :num_blocks_for_seq] = inv_perm[start:end]
        # Unused trailing entries point at the null block.
        block_table[i, num_blocks_for_seq:] = 0
        start_block_idx += num_blocks_for_seq

    # Create a realistic slot mapping that corresponds to the block table
    for i in range(batch_size):
        # Positions of the new (query) tokens within the sequence.
        token_offsets = torch.arange(int(query_lens[i])) + int(context_lens[i])
        block_indices = token_offsets // block_size
        token_inter_block_offsets = token_offsets % block_size
        start = common_attn_metadata.query_start_loc_cpu[i]
        end = common_attn_metadata.query_start_loc_cpu[i + 1]
        slot_mapping[start:end] = block_table[
            i, block_indices
        ] * block_size + token_inter_block_offsets.to(device)

    return kv_cache


271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
class MockSparseMLAAttentionLayer:
    """A mock sparse MLA attention layer for testing.

    Sparse MLA implementations only support forward_mqa (decode-style attention)
    for all tokens, so this class only implements that path.

    Unlike regular MLA impls, sparse MLA impls don't have W_UK_T and W_UV
    attributes. These transformations are done by the layer (MLAAttention),
    not the impl. This mock layer accepts these weight matrices directly.
    """

    def __init__(
        self,
        impl,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        kv_lora_rank: int,
        device: torch.device,
        W_UK: torch.Tensor,
        W_UV: torch.Tensor,
        q_scale: float,
        k_scale: float,
    ):
        """Store the impl, head dimensions, projection weights and scales.

        Args:
            impl: Backend attention implementation; ``forward_impl`` calls its
                ``forward_mqa``.
            num_heads: Number of attention heads (N).
            qk_nope_head_dim: Per-head width of the non-RoPE part of q (P).
            qk_rope_head_dim: Per-head width of the RoPE part of q.
            v_head_dim: Per-head output width (V).
            kv_lora_rank: Latent dimension (L).
            device: Device on which the scale tensors are created.
            W_UK: Key up-projection weight of shape (L, N, P).
            W_UV: Value up-projection weight of shape (L, N, V).
            q_scale: Query scale, kept both as a tensor and as a float.
            k_scale: Key scale, kept both as a tensor and as a float.
        """
        self.impl = impl
        self.num_heads = num_heads
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.kv_lora_rank = kv_lora_rank

        # Compute weight matrices in the format expected by forward_impl
        # W_UK shape: (L, N, P) -> W_UK_T shape: (N, P, L)
        self.W_UK_T = W_UK.permute(1, 2, 0)
        # W_UV shape: (L, N, V) -> (N, L, V)
        self.W_UV = W_UV.transpose(0, 1)

        # Scale attributes needed by attention backends
        # NOTE(review): the V scale is NaN -- presumably so that any
        # accidental use shows up in the output; confirm intent.
        self._q_scale = torch.tensor(q_scale, device=device)
        self._k_scale = torch.tensor(k_scale, device=device)
        self._v_scale = torch.tensor(float("nan"), device=device)
        self._prob_scale = torch.tensor(1.0, device=device)
        self._q_scale_float = q_scale
        self._k_scale_float = k_scale
        self._v_scale_float = float("nan")

        # Op used to concatenate and quantize the query when the impl
        # accepts a pre-quantized fp8 query (see forward_impl).
        self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
            static=True,
            group_shape=GroupShape.PER_TENSOR,
            compile_native=True,
        )

    def forward_impl(
        self,
        q: torch.Tensor,
        kv_c: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata,
        output: torch.Tensor,
    ) -> torch.Tensor:
        """Forward for sparse MLA - uses forward_mqa for all tokens.

        Writes the new ``kv_c`` / ``k_pe`` entries into ``kv_cache`` (when the
        cache is non-empty), runs ``impl.forward_mqa`` over every token, and
        applies the W_UV up-projection.

        Args:
            q: Query of shape (B, N, qk_nope_head_dim + qk_rope_head_dim).
            kv_c: New latent KV entries to write to the cache.
            k_pe: New key positional-embedding entries; dim 1 is squeezed
                before the cache write.
            kv_cache: KV cache tensor; an empty tensor skips the cache write.
            attn_metadata: Backend metadata; ``slot_mapping`` is read here
                and the object is passed on to ``forward_mqa``.
            output: Buffer whose first B rows receive the result.

        Returns:
            ``output``, with rows ``[:B]`` holding the (B, N * V) result.
        """
        kv_cache_dtype = getattr(self.impl, "kv_cache_dtype", "auto")
        fp8_attention = kv_cache_dtype.startswith("fp8")

        # Write to KV cache
        if kv_cache.numel() > 0:
            ops.concat_and_cache_mla(
                kv_c,
                k_pe.squeeze(1),
                kv_cache,
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype=kv_cache_dtype,
                scale=self._k_scale,
            )

        # Reinterpret the cache as the platform fp8 dtype; the fp8_ds_mla
        # layout is left as is.
        if fp8_attention and kv_cache_dtype != "fp8_ds_mla":
            kv_cache = kv_cache.view(current_platform.fp8_dtype())

        num_tokens = q.shape[0]

        # Sparse MLA uses forward_mqa for all tokens
        # Split q into nope and pe parts
        mqa_q_nope, mqa_q_pe = q.split(
            [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )

        # Convert from (B, N, P) to (N, B, P)
        mqa_q_nope = mqa_q_nope.transpose(0, 1)

        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
        mqa_ql_nope = torch.bmm(mqa_q_nope, self.W_UK_T)

        # Convert from (N, B, L) to (B, N, L)
        mqa_ql_nope = mqa_ql_nope.transpose(0, 1)

        if fp8_attention and self.impl.supports_quant_query_input:
            # The impl takes a single pre-quantized query tensor.
            assert mqa_ql_nope.shape[0] == mqa_q_pe.shape[0]
            assert mqa_ql_nope.shape[1] == mqa_q_pe.shape[1]
            mqa_q = self._decode_concat_quant_fp8_op(
                mqa_ql_nope, mqa_q_pe, self._q_scale
            )
        else:
            # Otherwise the two query parts are passed as a tuple.
            mqa_q = (mqa_ql_nope, mqa_q_pe)

        attn_out, _ = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self)

        # v_up projection: multiply by W_UV
        # attn_out shape: (B, N, L) where L = kv_lora_rank
        # W_UV shape: (N, L, V)
        # output shape: (B, N, V) -> flatten to (B, N*V)
        decode_output = torch.bmm(attn_out.transpose(0, 1), self.W_UV).transpose(0, 1)
        output[:num_tokens] = decode_output.reshape(
            num_tokens, self.num_heads * self.v_head_dim
        )

        return output


391
class MockMLAAttentionLayer(AttentionLayerBase):
    """A mock MLA attention layer for testing.

    This replicates the forward_impl logic from MLAAttention to allow
    testing MLA backends without the full layer infrastructure.

    The W_UK_T and W_UV weight matrices are created on the layer (like in
    MLAAttention.process_weights_after_loading), not on the impl.
    """

    def __init__(
        self,
        impl,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        kv_lora_rank: int,
        device: torch.device,
        kv_b_proj,
        q_scale: float,
        k_scale: float,
    ):
        """Store the impl and derive W_UK_T / W_UV from ``kv_b_proj``.

        Args:
            impl: Backend attention implementation; ``forward_impl`` calls its
                ``forward_mha`` and ``forward_mqa``.
            num_heads: Number of attention heads (N).
            qk_nope_head_dim: Per-head width of the non-RoPE part of q (P).
            qk_rope_head_dim: Per-head width of the RoPE part of q.
            v_head_dim: Per-head output width (V).
            kv_lora_rank: Latent dimension (L).
            device: Device on which the scale tensors are created.
            kv_b_proj: Object exposing ``.weight``; its transpose must be
                viewable as (L, N, P + V).
            q_scale: Query scale, kept both as a tensor and as a float.
            k_scale: Key scale, kept both as a tensor and as a float.
        """
        self.impl = impl
        self.num_heads = num_heads
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.kv_lora_rank = kv_lora_rank

        # Compute weight matrices from kv_b_proj (like MLAAttention does)
        # This replicates MLAAttention.process_weights_after_loading logic
        kv_b_proj_weight = kv_b_proj.weight.T
        kv_b_proj_weight = kv_b_proj_weight.view(
            kv_lora_rank,
            num_heads,
            qk_nope_head_dim + v_head_dim,
        )
        W_UK, W_UV = kv_b_proj_weight.split([qk_nope_head_dim, v_head_dim], dim=-1)
        # Convert from (L, N, V) to (N, L, V)
        self.W_UV = W_UV.transpose(0, 1)
        # Convert from (L, N, P) to (N, P, L)
        self.W_UK_T = W_UK.permute(1, 2, 0)

        # Scale attributes needed by attention backends
        # NOTE(review): the V scale is NaN -- presumably so that any
        # accidental use shows up in the output; confirm intent.
        self._q_scale = torch.tensor(q_scale, device=device)
        self._k_scale = torch.tensor(k_scale, device=device)
        self._v_scale = torch.tensor(float("nan"), device=device)
        self._prob_scale = torch.tensor(1.0, device=device)
        self._q_scale_float = q_scale
        self._k_scale_float = k_scale
        self._v_scale_float = float("nan")

        # Op used to concatenate and quantize the decode query when the impl
        # accepts a pre-quantized fp8 query (see forward_impl).
        self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
            static=True,
            group_shape=GroupShape.PER_TENSOR,
            compile_native=True,
        )

    def get_attn_backend(self):
        """Not supported by this mock."""
        raise NotImplementedError

    def get_kv_cache_spec(self, vllm_config):
        """Not supported by this mock."""
        raise NotImplementedError

    def forward_impl(
        self,
        q: torch.Tensor,
        kv_c: torch.Tensor,
        k_pe: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata,
        output: torch.Tensor,
    ) -> torch.Tensor:
        """Replicates MLAAttention.forward_impl logic for testing.

        Writes the new ``kv_c`` / ``k_pe`` entries into ``kv_cache`` (when the
        cache is non-empty), then treats the first
        ``attn_metadata.num_decode_tokens`` tokens as decode tokens
        (``impl.forward_mqa`` followed by the W_UV up-projection) and the
        remaining tokens as prefill tokens (``impl.forward_mha``).

        Args:
            q: Query of shape (B, N, qk_nope_head_dim + qk_rope_head_dim).
            kv_c: New latent KV entries.
            k_pe: New key positional-embedding entries; dim 1 is squeezed
                before the cache write.
            kv_cache: KV cache tensor; an empty tensor skips the cache write.
            attn_metadata: Backend metadata providing ``slot_mapping``,
                ``num_decode_tokens``, ``num_decodes`` and ``num_prefills``.
            output: Buffer that receives the result.

        Returns:
            ``output``, filled for the decode and/or prefill tokens.
        """
        # Write to KV cache
        kv_cache_dtype = getattr(self.impl, "kv_cache_dtype", "auto")
        fp8_attention = kv_cache_dtype.startswith("fp8")
        if kv_cache.numel() > 0:
            ops.concat_and_cache_mla(
                kv_c,
                k_pe.squeeze(1),
                kv_cache,
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype=kv_cache_dtype,
                scale=self._k_scale,
            )

        # Reinterpret the cache as the platform fp8 dtype; the fp8_ds_mla
        # layout is left as is.
        if fp8_attention and kv_cache_dtype != "fp8_ds_mla":
            kv_cache = kv_cache.view(current_platform.fp8_dtype())

        # Determine decode vs prefill split
        # (`or 0` treats a missing/None count as zero)
        num_decode_tokens = attn_metadata.num_decode_tokens or 0
        has_decode = (attn_metadata.num_decodes or 0) > 0
        has_prefill = (attn_metadata.num_prefills or 0) > 0

        # Run prefill with forward_mha
        if has_prefill:
            prefill_q = q[num_decode_tokens:]
            prefill_k_pe = k_pe[num_decode_tokens:]
            prefill_k_c = kv_c[num_decode_tokens:]
            # The result is written through the `output` slice passed here.
            self.impl.forward_mha(
                prefill_q,
                prefill_k_c,
                prefill_k_pe,
                kv_cache,
                attn_metadata,
                self._k_scale,
                output=output[num_decode_tokens:],
            )

        # Run decode with forward_mqa
        if has_decode:
            decode_q = q[:num_decode_tokens]

            # Split q into nope and pe parts
            mqa_q_nope, mqa_q_pe = decode_q.split(
                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
            )

            # Convert from (B, N, P) to (N, B, P)
            mqa_q_nope = mqa_q_nope.transpose(0, 1)

            # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
            mqa_ql_nope = torch.bmm(mqa_q_nope, self.W_UK_T)

            # Convert from (N, B, L) to (B, N, L)
            mqa_ql_nope = mqa_ql_nope.transpose(0, 1)

            if fp8_attention and self.impl.supports_quant_query_input:
                # The impl takes a single pre-quantized query tensor.
                assert mqa_ql_nope.shape[0] == mqa_q_pe.shape[0]
                assert mqa_ql_nope.shape[1] == mqa_q_pe.shape[1]
                mqa_q = self._decode_concat_quant_fp8_op(
                    mqa_ql_nope, mqa_q_pe, self._q_scale
                )
            else:
                # Otherwise the two query parts are passed as a tuple.
                mqa_q = (mqa_ql_nope, mqa_q_pe)

            attn_out, _ = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self)

            # v_up projection: multiply by W_UV
            # attn_out shape: (B, N, L) where L = kv_lora_rank
            # W_UV shape: (N, L, V)
            # output shape: (B, N, V) -> flatten to (B, N*V)
            decode_output = torch.bmm(attn_out.transpose(0, 1), self.W_UV).transpose(
                0, 1
            )
            output[:num_decode_tokens] = decode_output.reshape(
                num_decode_tokens, self.num_heads * self.v_head_dim
            )

        return output

Matthew Bonanni's avatar
Matthew Bonanni committed
544

545
def run_attention_backend(
    backend: AttentionBackendEnum,
    kv_cache_spec: MLAAttentionSpec,
    layer_names: list[str],
    vllm_config,
    device: torch.device,
    common_attn_metadata: CommonAttentionMetadata,
    query: torch.Tensor,
    kv_c: torch.Tensor,
    k_pe: torch.Tensor,
    kv_cache: torch.Tensor,
    kv_lora_rank: int,
    qk_nope_head_dim: int,
    qk_rope_head_dim: int,
    v_head_dim: int,
    mock_kv_b_proj,
    q_scale: float,
    k_scale: float,
    kv_cache_dtype: str = "auto",
) -> torch.Tensor:
    """Run attention computation using the specified backend's AttentionImpl.

    Instantiates the backend's impl and metadata builder, wraps the impl in a
    ``MockMLAAttentionLayer``, registers that layer under every name in
    ``layer_names`` in ``vllm_config.compilation_config.static_forward_context``
    and runs one forward pass.

    Args:
        backend: Backend whose builder and impl classes are used.
        kv_cache_spec: KV cache spec handed to the metadata builder.
        layer_names: Layer names to register the mock layer under.
        vllm_config: vLLM config; made current for the duration of the call.
        device: Device for the builder and the mock layer's scale tensors.
        common_attn_metadata: Common metadata the backend metadata is built
            from.
        query: Query tensor; its first dimension is the token count.
        kv_c: New latent KV entries.
        k_pe: New key positional-embedding entries.
        kv_cache: KV cache tensor.
        kv_lora_rank: Latent dimension.
        qk_nope_head_dim: Per-head width of the non-RoPE part of q.
        qk_rope_head_dim: Per-head width of the RoPE part of q.
        v_head_dim: Per-head output width.
        mock_kv_b_proj: Projection passed to both the impl and the mock layer.
        q_scale: Query scale for the mock layer.
        k_scale: Key scale for the mock layer.
        kv_cache_dtype: KV cache dtype string passed to the impl.

    Returns:
        Output tensor of shape (num_tokens, num_heads * v_head_dim) with the
        dtype and device of ``query``.
    """

    builder_cls, impl_cls = try_get_attention_backend(backend)

    # Set the current vllm config so that get_current_vllm_config() works
    # in the backend implementations
    with set_current_vllm_config(vllm_config):
        # Instantiate MLA implementation
        num_heads = vllm_config.model_config.get_num_attention_heads(
            vllm_config.parallel_config
        )
        num_kv_heads = vllm_config.model_config.get_num_kv_heads(
            vllm_config.parallel_config
        )
        head_size = vllm_config.model_config.get_head_size()
        scale = 1.0 / (head_size**0.5)
        impl = impl_cls(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype=kv_cache_dtype,
            logits_soft_cap=None,
            attn_type="decoder",
            kv_sharing_target_layer_name=None,
            q_lora_rank=None,
            kv_lora_rank=kv_lora_rank,
            qk_nope_head_dim=qk_nope_head_dim,
            qk_rope_head_dim=qk_rope_head_dim,
            qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
            v_head_dim=v_head_dim,
            kv_b_proj=mock_kv_b_proj,
        )

        # Process weights on the impl
        act_dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
        impl.process_weights_after_loading(act_dtype)

        # Initialize DCP attributes (normally set by MLAAttention.forward
        # before calling forward_mha, see mla_attention.py:511-512)
        if impl.dcp_world_size == -1:
            impl.dcp_world_size = 1

        # Create mock MLA layer
        mock_layer = MockMLAAttentionLayer(
            impl=impl,
            num_heads=num_heads,
            qk_nope_head_dim=qk_nope_head_dim,
            qk_rope_head_dim=qk_rope_head_dim,
            v_head_dim=v_head_dim,
            kv_lora_rank=kv_lora_rank,
            device=device,
            kv_b_proj=mock_kv_b_proj,
            q_scale=q_scale,
            k_scale=k_scale,
        )

        # Populate static_forward_context with mock attention layers
        for layer_name in layer_names:
            vllm_config.compilation_config.static_forward_context[layer_name] = (
                mock_layer
            )

        # Build metadata
        builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device)
        attn_metadata = builder.build(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
        )

        # Create output buffer
        num_tokens = query.shape[0]
        output = torch.empty(
            num_tokens, num_heads * v_head_dim, dtype=query.dtype, device=query.device
        )

        # Run forward pass
        output = mock_layer.forward_impl(
            query, kv_c, k_pe, kv_cache, attn_metadata, output
        )

        return output
Matthew Bonanni's avatar
Matthew Bonanni committed
649
650


651
652
653
654
655
656
657
658
659
660
661
662
663
@pytest.mark.parametrize(
    "batch_spec_name",
    [
        "small_decode",
        "small_prefill",
        "mixed_small",
        "medium_decode",
        "medium_prefill",
        "mixed_medium",
        "large_decode",
        "large_prefill",
        "single_decode",
        "single_prefill",
        "spec_decode_small",
        "spec_decode_medium",
    ],
)
@pytest.mark.parametrize("model", ["deepseek-ai/DeepSeek-R1"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 4, 8, 16])
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8", "fp8_e4m3"])
@pytest.mark.parametrize(("q_scale", "k_scale"), [(1.0, 1.0), (2.0, 3.0)])
def test_backend_correctness(
    default_vllm_config,
    dist_init,
    batch_spec_name: str,
    model: str,
    tensor_parallel_size: int,
    kv_cache_dtype: str,
    q_scale: float,
    k_scale: float,
):
    """
    Test that all backends produce similar outputs to a reference implementation
    using torch.nn.functional.scaled_dot_product_attention.

    This test works by:
    1. Generating a batch of sequences with specified context and query lengths.
    2. Computing a ground-truth attention output using torch.sdpa on
       contiguous Q, K, and V tensors.
    3. Simulating vLLM's paged KV cache: It takes the context portion of the
       K/V tensors and manually places them into a paged buffer according to
       the test's (randomly generated) block table.
    4. Running each vLLM attention backend with the new queries and the
       simulated paged KV cache.
    5. Comparing the vLLM backend's output to the ground-truth SDPA output.

    Note: When tensor_parallel_size > 1, we simulate the head partitioning
    by overriding the model config to use fewer heads, without requiring
    multiple GPUs. This tests that backends work correctly with different
    head counts.

    Args:
        default_vllm_config: Pytest fixture; not referenced in the body
            (presumably requested for its setup side effects -- confirm
            against conftest).
        dist_init: Pytest fixture; not referenced in the body (presumably
            initializes the distributed environment -- confirm against
            conftest).
        batch_spec_name: Key into ``BATCH_SPECS`` selecting the batch layout.
            Names starting with ``"spec_decode"`` enable the spec-decode path.
        model: Model name forwarded to ``create_vllm_config``.
        tensor_parallel_size: Divisor applied to the model's head counts to
            simulate head partitioning on a single device.
        kv_cache_dtype: KV cache dtype string; backends that do not list it in
            ``supported_kv_cache_dtypes`` are excluded.
        q_scale: Query scale forwarded to ``run_attention_backend``.
        k_scale: Key scale forwarded to ``run_attention_backend`` and used
            when pre-populating the KV cache.

    All backend failures are collected and reported together via
    ``pytest.fail`` so a single run shows every failing backend.
    """
    # Filter backends to those that support the requested kv_cache_dtype
    backends_to_test = [
        b
        for b in BACKENDS_TO_TEST
        if kv_cache_dtype in b.get_class().supported_kv_cache_dtypes
    ]
    if (
        q_scale != 1.0 or k_scale != 1.0
    ) and AttentionBackendEnum.CUTLASS_MLA in backends_to_test:
        # CUTLASS_MLA does not support non-1 Q/K scales
        backends_to_test.remove(AttentionBackendEnum.CUTLASS_MLA)
    if not backends_to_test:
        pytest.skip(f"No backends support kv_cache_dtype={kv_cache_dtype}")

    batch_spec = BATCH_SPECS[batch_spec_name]
    is_spec_decode_test = batch_spec_name.startswith("spec_decode")
    unique_block_sizes = sorted(set(BACKEND_BLOCK_SIZES[b] for b in backends_to_test))
    default_block_size = unique_block_sizes[0]
    required_blocks = sum(
        (seq_len + default_block_size - 1) // default_block_size
        for seq_len in batch_spec.seq_lens
    )
    # Add 1 for null block at index 0, and some buffer
    num_gpu_blocks = required_blocks + 1 + 100

    hf_config_override = None
    if tensor_parallel_size > 1:
        from vllm.config import ModelConfig

        temp_config = ModelConfig(model=model, max_model_len=1)
        original_num_heads = temp_config.hf_text_config.num_attention_heads
        original_num_kv_heads = getattr(
            temp_config.hf_text_config, "num_key_value_heads", None
        )
        hf_config_override = {
            "num_attention_heads": original_num_heads // tensor_parallel_size,
        }
        if original_num_kv_heads is not None:
            hf_config_override["num_key_value_heads"] = max(
                1, original_num_kv_heads // tensor_parallel_size
            )

    vllm_config = create_vllm_config(
        model_name=model,
        tensor_parallel_size=1,  # Always use TP=1 to avoid multi-GPU requirements
        max_model_len=max(batch_spec.seq_lens),
        num_gpu_blocks=num_gpu_blocks,
        block_size=default_block_size,
        hf_config_override=hf_config_override,
    )
    vllm_config.cache_config.cache_dtype = kv_cache_dtype

    # For spec decode tests, add a speculative_config to set the reorder_batch_threshold
    if is_spec_decode_test:
        from vllm.config import SpeculativeConfig

        # Get the query length from the batch spec (they should all be uniform)
        query_len = batch_spec.query_lens[0]
        # Set num_speculative_tokens to query_len - 1
        # (since threshold is 1 + num_spec_tokens)
        # Use ngram method which doesn't require a draft model
        vllm_config.speculative_config = SpeculativeConfig(
            method="ngram", num_speculative_tokens=query_len - 1
        )

    device = torch.device(f"{DEVICE_TYPE}:0")

    # 1. Setup
    batch_size = batch_spec.batch_size
    seq_lens = batch_spec.seq_lens
    query_lens = batch_spec.query_lens
    num_q_heads = vllm_config.model_config.get_num_attention_heads(
        vllm_config.parallel_config
    )
    head_size = vllm_config.model_config.get_head_size()
    dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
    kv_lora_rank = 512
    qk_rope_head_dim = 64
    qk_nope_head_dim = 128
    v_head_dim = 128
    total_head_size = kv_lora_rank + qk_rope_head_dim
    assert total_head_size == head_size, (
        f"MLA dimensions don't match: {total_head_size} != {head_size}"
    )
    scale = 1.0 / (total_head_size**0.5)

    # 2. Generate data and compute SDPA reference output for MLA
    all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
    # One list of per-sequence reference outputs for each backend under test
    all_sdpa_outputs: list[list[torch.Tensor]] = [[] for _ in backends_to_test]
    kv_c_contexts, k_pe_contexts = [], []

    # Create shared MLA weight matrices for consistency across all sequences
    W_UK = torch.randn(
        kv_lora_rank, num_q_heads, qk_nope_head_dim, dtype=dtype, device=device
    )
    W_UV = torch.randn(
        kv_lora_rank, num_q_heads, v_head_dim, dtype=dtype, device=device
    )

    # Scale weights to produce realistic magnitude outputs.
    # Without scaling, projection output has std ~sqrt(kv_lora_rank) ≈ 22.6,
    # causing extreme attention scores and numerical instability in LSE merging.
    weight_scale = 1.0 / (kv_lora_rank**0.5)
    W_UK = W_UK * weight_scale
    W_UV = W_UV * weight_scale

    kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)

    # The builder attributes that decide decode vs. prefill do not depend on
    # the sequence, so look them up once per backend instead of per sequence.
    backend_traits = []
    for backend in backends_to_test:
        builder_cls, _ = try_get_attention_backend(backend)
        backend_traits.append(
            (
                getattr(builder_cls, "query_len_support", QueryLenSupport.SINGLE_ONLY),
                getattr(builder_cls, "reorder_batch_threshold", None),
            )
        )

    for i in range(batch_size):
        s_len = seq_lens[i]
        q_len = query_lens[i]
        context_len = s_len - q_len

        # Generate MLA tensors
        # Q has both nope and rope components:
        # [q_len, num_heads, qk_nope_head_dim + qk_rope_head_dim]
        q_c = torch.randn(
            q_len,
            num_q_heads,
            qk_nope_head_dim + qk_rope_head_dim,
            dtype=dtype,
            device=device,
        )

        # KV_C (latent K/V): [s_len, kv_lora_rank]
        kv_c_full = torch.randn(s_len, kv_lora_rank, dtype=dtype, device=device)

        # K_PE (rope component): [s_len, 1, qk_rope_head_dim]
        k_pe_full = torch.randn(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device)

        # Determine if this sequence uses the decode pipeline or prefill
        # pipeline for each backend
        # NOTE: For spec decode tests with uniform query_len > 1, backends that
        # support spec decode (FLASH_ATTN_MLA with varlen support, FLASHMLA with
        # uniform support) will use the decode pipeline (MQA-style), while
        # backends that only support single-token queries will use the prefill
        # pipeline (MHA-style). This ensures the reference implementation
        # matches each backend's actual decode/prefill pipeline path.
        is_decode = []
        for query_len_support, threshold in backend_traits:
            if is_spec_decode_test:
                is_decode.append(query_len_support != QueryLenSupport.SINGLE_ONLY)
                continue
            # A missing (None) or zero threshold means "never decode"
            within_threshold = q_len <= threshold if threshold else False
            if (
                within_threshold
                and query_len_support == QueryLenSupport.UNIFORM
                and i > 0
            ):
                # Uniform-only backends additionally require this query length
                # to equal the first sequence's query length
                within_threshold = q_len == query_lens[0]
            is_decode.append(within_threshold)

        # Split q into nope and rope components
        q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)

        # Attention mask shared by the decode and prefill references:
        # - Query tokens can attend to all context tokens
        # - Query tokens can only attend to query tokens up to their position
        attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
        # Apply causal mask only to the query portion (context_len onwards)
        causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
        attn_mask[:, context_len:] = causal_mask

        #######################################################
        # Decode path: MQA-style attention in latent space
        # Transform q_nope to latent space: q_nope @ W_UK
        # q_nope: [q_len, num_heads, qk_nope_head_dim]
        # W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim]
        ql_nope = torch.einsum(
            "qnh,lnh->qnl", q_nope, W_UK
        )  # [q_len, num_heads, kv_lora_rank]

        # Build MQA attention inputs
        # Q: [q_len, num_heads, kv_lora_rank + qk_rope_head_dim]
        q_mqa = torch.cat([ql_nope, q_pe], dim=-1)
        # K: [s_len, kv_lora_rank + qk_rope_head_dim]
        # (broadcasted to all heads)
        k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1)
        k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1)
        # V: [s_len, kv_lora_rank] (broadcasted to all heads)
        v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1)

        # SDPA expects (N, H, L, D)
        q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
        k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
        v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)

        sdpa_out_i_decode = torch.nn.functional.scaled_dot_product_attention(
            q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale
        )
        sdpa_out_i_decode = sdpa_out_i_decode.transpose(1, 2).squeeze(
            0
        )  # [q_len, num_heads, kv_lora_rank]

        # Project back to output space: sdpa_out @ W_UV
        sdpa_out_i_decode = torch.einsum("qnl,lnv->qnv", sdpa_out_i_decode, W_UV)
        sdpa_out_i_decode = sdpa_out_i_decode.flatten(start_dim=-2)

        #######################################################
        # Prefill path: MHA-style attention with full sequence
        # Apply kv_b_proj to the full kv_c tensor
        kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, kv_b_proj_weight)
        k_nope_full, v_full = kv_nope_full.split([qk_nope_head_dim, v_head_dim], dim=-1)

        # Build attention inputs for full sequence
        q_mha = torch.cat([q_nope, q_pe], dim=-1)  # [q_len, num_heads, total_dim]
        k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1)
        k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1)

        # SDPA expects (N, H, L, D)
        q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2)
        k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2)
        v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2)

        # Single attention call with custom mask
        sdpa_out_i_prefill = torch.nn.functional.scaled_dot_product_attention(
            q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale
        )
        sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0)
        sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2)

        # Record, per backend, the reference matching the pipeline it will use
        for backend_idx, uses_decode in enumerate(is_decode):
            all_sdpa_outputs[backend_idx].append(
                sdpa_out_i_decode if uses_decode else sdpa_out_i_prefill
            )

        # Inputs for vLLM MLA backends are just the new tokens
        all_q_vllm.append(q_c)
        all_kv_c_vllm.append(kv_c_full[context_len:])  # New kv_c tokens
        all_k_pe_vllm.append(k_pe_full[context_len:])  # New k_pe tokens

        # Contextual K/V data used to populate the paged cache (MLA format)
        kv_c_contexts.append(kv_c_full[:context_len])
        k_pe_contexts.append(k_pe_full[:context_len])

    # Concatenate all sequences (no reordering needed)
    query_vllm = torch.cat(all_q_vllm, dim=0)
    kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
    k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
    sdpa_outputs = {
        backend: torch.cat(per_seq_outputs, dim=0)
        for backend, per_seq_outputs in zip(backends_to_test, all_sdpa_outputs)
    }

    # Create mock kv_b_proj using the same weights as reference implementation
    from vllm.model_executor.layers.linear import ColumnParallelLinear

    mock_kv_b_proj = ColumnParallelLinear(
        input_size=kv_lora_rank,
        output_size=num_q_heads * (qk_nope_head_dim + v_head_dim),
        bias=False,
    ).to(device=device, dtype=dtype)

    # Set the mock weights to match our reference implementation
    # Reshape W_UK and W_UV to match the expected kv_b_proj format
    # [kv_lora_rank, num_heads, qk_nope_head_dim + v_head_dim]
    kv_b_proj_weight = kv_b_proj_weight.view(
        kv_lora_rank, num_q_heads * (qk_nope_head_dim + v_head_dim)
    )
    mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T, requires_grad=False)

    # 3. Create metadata and KV caches for each block size
    # Group backends by block size and test each group
    metadata_per_block_size = {}
    kv_cache_per_block_size = {}

    for block_size in unique_block_sizes:
        # Create metadata for this block size
        common_attn_metadata = create_common_attn_metadata(
            batch_spec, block_size, device
        )

        # Pad block table to meet requirement:
        # block_num % (128 / block_size) == 0
        # Floor the divisor at 1 so block sizes above 128 cannot trigger a
        # modulo-by-zero below.
        required_divisor = max(1, 128 // block_size)
        current_block_num = common_attn_metadata.block_table_tensor.shape[1]
        if current_block_num % required_divisor != 0:
            # Pad to next multiple of required_divisor
            padded_block_num = (
                (current_block_num + required_divisor - 1) // required_divisor
            ) * required_divisor
            padding_cols = padded_block_num - current_block_num
            padding = torch.zeros(
                (common_attn_metadata.block_table_tensor.shape[0], padding_cols),
                dtype=torch.int32,
                device=device,
            )
            common_attn_metadata.block_table_tensor = torch.cat(
                [common_attn_metadata.block_table_tensor, padding], dim=1
            )

        metadata_per_block_size[block_size] = common_attn_metadata

        # Create KV cache for this block size
        required_blocks_for_size = sum(
            (seq_len + block_size - 1) // block_size for seq_len in batch_spec.seq_lens
        )
        num_blocks_for_size = required_blocks_for_size + 1 + 100

        kv_cache = create_and_prepopulate_kv_cache(
            kv_c_contexts=kv_c_contexts,
            k_pe_contexts=k_pe_contexts,
            block_size=block_size,
            head_size=head_size,
            dtype=dtype,
            device=device,
            num_blocks=num_blocks_for_size,
            common_attn_metadata=common_attn_metadata,
            randomize_blocks=True,
            kv_cache_dtype=kv_cache_dtype,
            scale=k_scale,
        )
        kv_cache_per_block_size[block_size] = kv_cache

    # 4. Run vLLM backends and compare
    failures = []
    failed_backends = []
    for backend_name in backends_to_test:
        # Skip backends that don't support spec decode for spec decode tests
        if is_spec_decode_test and backend_name not in SPEC_DECODE_BACKENDS:
            continue

        # Get the appropriate block_size, metadata, and cache for this backend
        block_size = BACKEND_BLOCK_SIZES[backend_name]
        common_attn_metadata = metadata_per_block_size[block_size]
        kv_cache = kv_cache_per_block_size[block_size]

        # Create kv_cache_spec with the correct block_size for this backend
        backend_kv_cache_spec = MLAAttentionSpec(
            block_size=block_size,
            num_kv_heads=vllm_config.model_config.get_num_kv_heads(
                vllm_config.parallel_config
            ),
            head_size=vllm_config.model_config.get_head_size(),
            dtype=vllm_config.model_config.dtype,
            sliding_window=vllm_config.model_config.get_sliding_window(),
            cache_dtype_str=kv_cache_dtype,
        )

        backend_output = run_attention_backend(
            backend_name,
            backend_kv_cache_spec,
            ["placeholder"],
            vllm_config,
            device,
            common_attn_metadata,
            query_vllm,
            kv_c_vllm,
            k_pe_vllm,
            kv_cache,
            kv_lora_rank,
            qk_nope_head_dim,
            qk_rope_head_dim,
            v_head_dim,
            mock_kv_b_proj,
            q_scale=q_scale,
            k_scale=k_scale,
            kv_cache_dtype=kv_cache_dtype,
        )

        # Reference output built for the pipeline (decode/prefill) this
        # backend takes for each sequence
        expected_output = sdpa_outputs[backend_name]

        # Check shape and dtype consistency
        try:
            assert backend_output.shape == expected_output.shape, (
                f"[{backend_name}] shape {backend_output.shape} != "
                f"SDPA shape {expected_output.shape}"
            )
            assert backend_output.dtype == expected_output.dtype, (
                f"[{backend_name}] dtype {backend_output.dtype} != "
                f"SDPA dtype {expected_output.dtype}"
            )

            assert torch.isfinite(backend_output).all(), (
                f"[{backend_name}] produced non-finite values"
            )

            # Check numerical similarity
            rtol = 1e-2
            atol = 5e-1

            # Diagnostics are computed in float32; the denominator is clamped
            # so reference elements equal to zero do not yield nan/inf.
            abs_diff = (backend_output.float() - expected_output.float()).abs()
            max_diff = abs_diff.max().item()
            max_rel_diff = (
                (abs_diff / expected_output.float().abs().clamp_min(1e-12))
                .max()
                .item()
            )
            all_close = torch.allclose(
                backend_output, expected_output, rtol=rtol, atol=atol
            )

            assert all_close, (
                f"[{backend_name}] output differs from SDPA baseline. "
                f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f}"
            )
        except AssertionError as e:
            # Keep going so every failing backend is reported in one run
            failures.append(str(e))
            failed_backends.append(f"{backend_name}")

    # Report all failures at once
    if failures:
        # Create a summary for the single-line failure message
        summary = f"{len(failures)} backend(s) failed: {', '.join(failed_backends)}"
        detailed_msg = "\n".join(failures)
        pytest.fail(f"{summary}\n{detailed_msg}")