# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for the FlashMLA sparse backend utilities."""

import math
from types import MethodType, SimpleNamespace

import numpy as np
import pytest
import torch

from tests.v1.attention.test_mla_backends import (
    BATCH_SPECS,
    BatchSpec,
    MockAttentionLayer,
    create_and_prepopulate_kv_cache,
)
from tests.v1.attention.utils import (
    create_common_attn_metadata,
    create_standard_kv_cache_spec,
    create_vllm_config,
)
from vllm import _custom_ops as ops
from vllm.attention.ops import flashmla
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.utils import cdiv
from vllm.v1.attention.backends.mla.flashmla_sparse import (
    FlashMLASparseBackend,
    FlashMLASparseDecodeAndContextMetadata,
    FlashMLASparseImpl,
    FlashMLASparseMetadata,
)
from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks

# Batch layouts exercised by the sparse backend correctness test: a subset of
# the shared MLA batch specs plus two larger-query prefill cases below.
SPARSE_BACKEND_BATCH_SPECS = {
    name: BATCH_SPECS[name]
    for name in [
        "mixed_small",
        "mixed_medium",
        "small_prefill",
        "medium_prefill",
        "single_prefill",
    ]
}

# Two requests, each with 256 new query tokens on top of 768 cached tokens.
SPARSE_BACKEND_BATCH_SPECS["large_q_prefill"] = BatchSpec(
    seq_lens=[1024] * 2, query_lens=[256] * 2
)
# Two requests made entirely of new tokens (no cached context).
SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
    seq_lens=[256] * 2, query_lens=[256] * 2
)


def _dequantize_fp8_ds_mla_entry(
55
56
    cache_slice: torch.Tensor, kv_lora_rank: int, rope_dim: int, dtype: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
57
58
59
60
    """Dequantize a single fp8_ds_mla cache entry back to latent + rope."""

    # The first kv_lora_rank bytes store FP8 latent values with one scale per
    # 128 element tile written as float32 right after the latent payload.
61
62
    scales = cache_slice.view(torch.float32)[kv_lora_rank // 4 : kv_lora_rank // 4 + 4]
    latent = torch.empty(kv_lora_rank, dtype=torch.float16, device=cache_slice.device)
63
64
65
    for tile_idx in range(4):
        tile_start = tile_idx * 128
        tile_end = tile_start + 128
66
67
68
69
70
71
        ops.convert_fp8(
            latent[tile_start:tile_end],
            cache_slice[tile_start:tile_end],
            float(scales[tile_idx].item()),
            kv_dtype="fp8",
        )
72
73
74
    latent = latent.to(dtype)

    rope_offset = kv_lora_rank // 2 + 8
75
    rope_vals = cache_slice.view(dtype)[rope_offset : rope_offset + rope_dim]
76
77
78
79
    return latent, rope_vals.clone()


def _quantize_dequantize_fp8_ds_mla(
80
81
    kv_c: torch.Tensor, k_pe: torch.Tensor, block_size: int, scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
82
83
84
85
86
87
88
89
90
91
92
    """Round-trip kv_c/k_pe though the fp8_ds_mla cache layout."""

    if kv_c.numel() == 0:
        return kv_c.clone(), k_pe.clone()

    kv_lora_rank = kv_c.shape[-1]
    rope_dim = k_pe.shape[-1]
    num_tokens = kv_c.shape[0]
    num_blocks = max(1, math.ceil(num_tokens / block_size))
    entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim

93
94
95
96
97
98
99
100
    tmp_cache = torch.zeros(
        num_blocks, block_size, entry_size, dtype=torch.uint8, device=kv_c.device
    )
    slot_mapping = torch.arange(num_tokens, dtype=torch.long, device=kv_c.device)

    ops.concat_and_cache_mla(
        kv_c, k_pe, tmp_cache, slot_mapping, kv_cache_dtype="fp8_ds_mla", scale=scale
    )
101
102
103
104
105
106
107
108
109
110

    dequant_kv_c = torch.empty_like(kv_c)
    dequant_k_pe = torch.empty_like(k_pe)

    for token_idx in range(num_tokens):
        slot = slot_mapping[token_idx].item()
        block_idx = slot // block_size
        block_offset = slot % block_size
        cache_slice = tmp_cache[block_idx, block_offset]
        latent, rope_vals = _dequantize_fp8_ds_mla_entry(
111
112
            cache_slice, kv_lora_rank, rope_dim, kv_c.dtype
        )
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
        dequant_kv_c[token_idx] = latent
        dequant_k_pe[token_idx] = rope_vals

    return dequant_kv_c, dequant_k_pe


def test_sparse_backend_metadata_registration():
    """The sparse backend reports its name, classes, dtypes and cache shape."""
    backend = FlashMLASparseBackend

    assert backend.get_name() == "FLASHMLA_SPARSE_VLLM_V1"
    assert backend.get_metadata_cls() is FlashMLASparseMetadata
    assert backend.get_impl_cls() is FlashMLASparseImpl

    dtype_list = backend.get_supported_dtypes()
    assert torch.bfloat16 in dtype_list

    shape = backend.get_kv_cache_shape(
        num_blocks=2, block_size=64, num_kv_heads=1, head_size=576
    )
    # The expected shape carries no separate KV-head dimension.
    assert shape == (2, 64, 576)


def test_sparse_decode_metadata_filters_prefill_indices():
    """filter_prefill_indices matches hand-computed context/new-token tensors.

    Entries equal to -1 in the expected tensors are positions that the split
    is expected to mask out.
    """
    prefill_context_lengths = torch.tensor([4, 2], dtype=torch.int32)
    metadata = FlashMLASparseDecodeAndContextMetadata(
        scheduler_metadata=torch.tensor([[0]], dtype=torch.int32),
        num_splits=torch.tensor([1, 1], dtype=torch.int32),
        cache_lens=torch.tensor([10, 12], dtype=torch.int32),
        prefill_context_lengths=prefill_context_lengths,
    )

    indices = torch.tensor([[0, 3, 5], [1, 2, 4]], dtype=torch.int32)

    context_indices, new_token_indices = metadata.filter_prefill_indices(indices)

    expected_context = torch.tensor([[-1, -1, 5], [-1, -1, 4]], dtype=torch.int32)
    expected_new_tokens = torch.tensor([[-1, -1, 1], [-1, 0, 2]], dtype=torch.int32)

    assert torch.equal(context_indices, expected_context)
    assert torch.equal(new_token_indices, expected_new_tokens)


def test_sparse_impl_zero_fills_when_metadata_missing():
    """forward() with attn_metadata=None zeroes and returns the output buffer."""
    # Bypass __init__: only the metadata-less path of forward is exercised.
    impl = FlashMLASparseImpl.__new__(FlashMLASparseImpl)
    dummy_layer = object()
    q = torch.zeros((2, 1, 3))
    k_c = torch.zeros((2, 3))
    k_pe = torch.zeros((2, 1, 1))
    kv_cache = torch.zeros((1, 1, 1))
    # Pre-filled with ones so the zero-fill is observable.
    output = torch.ones((2, 4))

    result = FlashMLASparseImpl.forward(
        impl, dummy_layer, q, k_c, k_pe, kv_cache, attn_metadata=None, output=output
    )

    assert result is output
    assert torch.all(result == 0)


@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype):
    """Compare FlashMLASparseImpl.forward against a dense SDPA reference.

    The top-k index buffer handed to the backend lists every token position
    ``<= `` the query's own position (up to ``topk_tokens``). The backend
    output is checked against causal scaled-dot-product attention computed on
    the same (fp8_ds_mla round-tripped) latents.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA is required for sparse MLA decode test")

    # Skip before building any inputs when the FlashMLA kernels are unusable.
    ok, reason = flashmla.is_flashmla_supported()
    if not ok:
        pytest.skip(reason)

    device = torch.device("cuda")
    dtype = torch.bfloat16

    batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name]

    # MLA dimensions. kv_lora_rank=512 matches the four 128-element FP8 tiles
    # decoded by _dequantize_fp8_ds_mla_entry.
    num_heads = 128
    kv_lora_rank = 512
    qk_nope_head_dim = 128
    qk_rope_head_dim = 64
    v_head_dim = 128
    head_size = kv_lora_rank + qk_rope_head_dim
    topk_tokens = 2048

    max_seqlen = max(batch_spec.seq_lens)
    total_cache_tokens = sum(batch_spec.seq_lens)
    block_size = 64

    vllm_config = create_vllm_config(
        model_name="deepseek-ai/DeepSeek-V2-Lite-Chat",
        max_model_len=max_seqlen,
        num_gpu_blocks=max(2048, cdiv(total_cache_tokens, block_size) + 1),
        block_size=block_size,
    )
    model_config = vllm_config.model_config
    model_config.hf_config = SimpleNamespace(
        attn_module_list_cfg=[{"topk_tokens": topk_tokens}]
    )
    model_config.hf_text_config = SimpleNamespace(
        q_lora_rank=None,
        kv_lora_rank=kv_lora_rank,
        qk_nope_head_dim=qk_nope_head_dim,
        qk_rope_head_dim=qk_rope_head_dim,
        v_head_dim=v_head_dim,
        model_type="deepseek_v2",
    )
    model_config.dtype = dtype
    model_config.get_num_attention_heads = MethodType(
        lambda self, parallel_config: num_heads, model_config
    )
    model_config.get_num_kv_heads = MethodType(
        lambda self, parallel_config: 1, model_config
    )
    model_config.get_head_size = MethodType(lambda self: head_size, model_config)
    model_config.get_sliding_window = MethodType(lambda self: None, model_config)

    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)

    torch.manual_seed(0)

    scale = 1.0 / math.sqrt(head_size)

    # Shared MLA projection weights to keep reference and backend in sync
    W_UK = torch.randn(
        kv_lora_rank, num_heads, qk_nope_head_dim, dtype=dtype, device=device
    )
    W_UV = torch.randn(kv_lora_rank, num_heads, v_head_dim, dtype=dtype, device=device)

    # Build a synthetic workload: request i has ctx_len cached tokens followed
    # by q_len new query tokens (prefill and/or decode, per batch_spec).
    seq_lens = batch_spec.seq_lens
    query_lens = batch_spec.query_lens

    all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
    kv_c_contexts, k_pe_contexts = [], []
    reference_outputs = []

    kv_cache_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    for i in range(batch_spec.batch_size):
        s_len = seq_lens[i]
        q_len = query_lens[i]
        ctx_len = s_len - q_len

        q_c = torch.rand(
            q_len,
            num_heads,
            qk_nope_head_dim + qk_rope_head_dim,
            dtype=dtype,
            device=device,
        )
        kv_c_full = torch.rand(s_len, kv_lora_rank, dtype=dtype, device=device)
        k_pe_full = torch.rand(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device)

        # Round-trip through the fp8_ds_mla layout (for both cache dtypes) so
        # the reference and the cached values share the same quantization.
        kv_c_full, k_pe_full = _quantize_dequantize_fp8_ds_mla(
            kv_c_full,
            k_pe_full.squeeze(1),
            block_size=vllm_config.cache_config.block_size,
            scale=kv_cache_scale,
        )

        # Absorb W_UK into the query so attention runs in the latent space.
        q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
        ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, W_UK)
        q_mqa = torch.cat([ql_nope, q_pe], dim=-1)

        k_mqa = torch.cat([kv_c_full, k_pe_full], dim=-1)
        k_mqa = k_mqa.unsqueeze(1).expand(-1, num_heads, -1)
        v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_heads, -1)

        # Context tokens are fully visible; new tokens are causally masked.
        attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
        causal_mask = torch.tril(
            torch.ones(q_len, q_len, dtype=torch.bool, device=device)
        )
        attn_mask[:, ctx_len:] = causal_mask

        q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
        k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
        v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)

        sdpa_out = torch.nn.functional.scaled_dot_product_attention(
            q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale
        )
        sdpa_out = sdpa_out.transpose(1, 2).squeeze(0)

        # Project the latent attention output back out with W_UV.
        sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV)
        reference_outputs.append(sdpa_out.flatten(start_dim=-2))

        all_q_vllm.append(q_c)
        all_kv_c_vllm.append(kv_c_full[ctx_len:])
        all_k_pe_vllm.append(k_pe_full[ctx_len:])
        # NOTE(review): one token past the context is prepopulated too,
        # presumably to keep pure-prefill contexts non-empty -- confirm.
        kv_c_contexts.append(kv_c_full[: ctx_len + 1])
        k_pe_contexts.append(k_pe_full[: ctx_len + 1])

    query_vllm = torch.cat(all_q_vllm, dim=0)
    kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
    k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
    sdpa_reference = torch.cat(reference_outputs, dim=0)

    vllm_config.cache_config.cache_dtype = kv_cache_dtype

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        vllm_config.cache_config.block_size,
        device,
        arange_block_indices=True,
    )

    kv_cache = create_and_prepopulate_kv_cache(
        kv_c_contexts=kv_c_contexts,
        k_pe_contexts=k_pe_contexts,
        block_size=vllm_config.cache_config.block_size,
        head_size=head_size,
        dtype=dtype,
        device=device,
        num_blocks=vllm_config.cache_config.num_gpu_blocks,
        common_attn_metadata=common_attn_metadata,
        randomize_blocks=False,
        kv_cache_dtype=vllm_config.cache_config.cache_dtype,
        scale=kv_cache_scale,
    )

    builder_cls = FlashMLASparseBackend.get_builder_cls()
    builder = builder_cls(kv_cache_spec, ["placeholder"], vllm_config, device)
    metadata = builder.build(
        common_prefix_len=0, common_attn_metadata=common_attn_metadata
    )

    # Absolute position of every query token within its own request:
    # offset inside the query segment plus the request's cached prefix length.
    starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
    seg_lengths = np.diff(starts)
    positions = np.arange(starts[-1], dtype=np.int32) - np.repeat(
        starts[:-1], seg_lengths
    )
    seq_lengths = np.asarray(common_attn_metadata.seq_lens_cpu, dtype=np.int32)
    prefix_lengths = seq_lengths - seg_lengths
    positions += np.repeat(prefix_lengths, seg_lengths)

    # Top-k indices: every position <= the token's own position, else -1.
    pos_gpu = torch.as_tensor(positions, device=device, dtype=torch.int32)
    topk = metadata.topk_tokens
    debug_indices = torch.arange(topk, device=device, dtype=torch.int32).unsqueeze(0)
    token_positions = pos_gpu.unsqueeze(1)
    causal_mask = debug_indices <= token_positions
    debug_indices = torch.where(
        causal_mask, debug_indices, torch.full_like(debug_indices, -1)
    )

    # FlashMLASparseImpl now reads top-k indices from the indexer-provided
    # buffer, so emulate that contract with a simple namespace mock.
    debug_indices = debug_indices.expand(metadata.num_actual_tokens, -1).clone()
    mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices)

    kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
    kv_b_proj_weight = kv_b_proj_weight.view(
        kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)
    )

    mock_kv_b_proj = ColumnParallelLinear(
        input_size=kv_lora_rank,
        output_size=num_heads * (qk_nope_head_dim + v_head_dim),
        bias=False,
    ).to(device=device, dtype=dtype)
    mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous())

    impl_cls = FlashMLASparseBackend.get_impl_cls()
    impl = impl_cls(
        num_heads=num_heads,
        head_size=head_size,
        scale=scale,
        num_kv_heads=1,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype=vllm_config.cache_config.cache_dtype,
        logits_soft_cap=None,
        attn_type="decoder",
        kv_sharing_target_layer_name=None,
        q_lora_rank=None,
        kv_lora_rank=kv_lora_rank,
        qk_nope_head_dim=qk_nope_head_dim,
        qk_rope_head_dim=qk_rope_head_dim,
        qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
        v_head_dim=v_head_dim,
        kv_b_proj=mock_kv_b_proj,
        indexer=mock_indexer,
    )

    impl.process_weights_after_loading(dtype)

    layer = MockAttentionLayer(device)
    out_buffer = torch.empty(
        metadata.num_actual_tokens, num_heads * v_head_dim, dtype=dtype, device=device
    )

    backend_output = impl.forward(
        layer, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache, metadata, output=out_buffer
    )

    assert backend_output.shape == sdpa_reference.shape
    assert backend_output.dtype == sdpa_reference.dtype
    assert torch.isfinite(backend_output).all()

    torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.5, atol=0.5)


@pytest.mark.parametrize(
    ("seq_lens", "max_buf", "start", "expected"),
    [
        # Greedy split: the summed lengths in each chunk stay <= max_buf
        (torch.tensor([2, 3, 4, 2]), 5, 0, [(0, 2), (2, 3), (3, 4)]),
        # Splitting begins at a non-zero request index
        (torch.tensor([2, 3, 4, 2]), 5, 1, [(1, 2), (2, 3), (3, 4)]),
        # An exactly-full chunk is closed, because adding the next request
        # would overflow the buffer
        (torch.tensor([5, 5, 5]), 5, 0, [(0, 1), (1, 2), (2, 3)]),
        # Every request fits into a single chunk
        (torch.tensor([1, 1, 1]), 10, 0, [(0, 3)]),
        # Roomy buffer combined with a non-zero start index
        (torch.tensor([4, 4, 4]), 100, 1, [(1, 3)]),
    ],
)
def test_split_prefill_chunks(seq_lens, max_buf, start, expected):
    """split_prefill_chunks yields the expected half-open request ranges."""
    assert split_prefill_chunks(seq_lens, max_buf, start) == expected