# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for the FlashMLA sparse backend utilities."""

import math
from types import MethodType, SimpleNamespace

import numpy as np
import pytest
import torch

from tests.v1.attention.test_mla_backends import (
    BATCH_SPECS,
    BatchSpec,
    MockAttentionLayer,
    create_and_prepopulate_kv_cache,
)
from tests.v1.attention.utils import (
    create_common_attn_metadata,
    create_standard_kv_cache_spec,
    create_vllm_config,
)
from vllm import _custom_ops as ops
from vllm.attention.ops import flashmla
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.utils import cdiv
from vllm.v1.attention.backends.mla.flashmla_sparse import FlashMLASparseBackend
from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks

# Subset of the shared MLA batch specs that the sparse backend is tested on.
SPARSE_BACKEND_BATCH_SPECS = {
    name: BATCH_SPECS[name]
    for name in [
        "mixed_small",
        "mixed_medium",
        "small_prefill",
        "medium_prefill",
        "single_prefill",
    ]
}

# Two requests, each appending 256 new tokens to a 768-token context.
SPARSE_BACKEND_BATCH_SPECS["large_q_prefill"] = BatchSpec(
    seq_lens=[1024] * 2, query_lens=[256] * 2
)
# Two pure-prefill requests: every one of the 256 tokens is a query token.
SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
    seq_lens=[256] * 2, query_lens=[256] * 2
)


def _dequantize_fp8_ds_mla_entry(
    cache_slice: torch.Tensor, kv_lora_rank: int, rope_dim: int, dtype: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    """Dequantize a single fp8_ds_mla cache entry back to latent + rope.

    The entry is treated as a flat byte buffer with three consecutive regions:

    1. ``kv_lora_rank`` bytes of FP8 latent values, split into four
       128-element tiles;
    2. four float32 scales, one per latent tile;
    3. ``rope_dim`` rope values stored unquantized as ``dtype``.

    Args:
        cache_slice: One cache entry as a contiguous 1-D byte tensor.
        kv_lora_rank: Number of latent elements; must equal 4 * 128.
        rope_dim: Number of rope elements.
        dtype: Dtype of the stored rope values and of the returned tensors.

    Returns:
        ``(latent, rope)`` tensors with ``kv_lora_rank`` and ``rope_dim``
        elements respectively, both in ``dtype``.

    Raises:
        ValueError: If ``kv_lora_rank`` does not match the four-tile layout.
    """
    tile_size = 128
    num_tiles = 4
    if kv_lora_rank != num_tiles * tile_size:
        raise ValueError(
            f"fp8_ds_mla expects kv_lora_rank == {num_tiles * tile_size}, "
            f"got {kv_lora_rank}"
        )

    # Byte offsets of the three regions inside the entry.
    scales_start = kv_lora_rank
    rope_start = scales_start + num_tiles * 4  # float32 scales are 4 bytes each
    rope_end = rope_start + rope_dim * dtype.itemsize

    scales = cache_slice[scales_start:rope_start].view(torch.float32)

    # Dequantize tile by tile (each tile has its own scale) into an fp16
    # scratch buffer, then cast to the requested dtype.
    latent = torch.empty(kv_lora_rank, dtype=torch.float16, device=cache_slice.device)
    for tile_idx in range(num_tiles):
        tile = slice(tile_idx * tile_size, (tile_idx + 1) * tile_size)
        ops.convert_fp8(
            latent[tile],
            cache_slice[tile],
            scales[tile_idx].item(),
            kv_dtype="fp8",
        )

    # Slice in bytes before reinterpreting, so the rope offset stays correct
    # regardless of dtype's element size.
    rope_vals = cache_slice[rope_start:rope_end].view(dtype)
    return latent.to(dtype), rope_vals.clone()


def _quantize_dequantize_fp8_ds_mla(
    kv_c: torch.Tensor, k_pe: torch.Tensor, block_size: int, scale: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Round-trip kv_c/k_pe through the fp8_ds_mla cache layout.

    The tokens are written into a scratch fp8_ds_mla cache with
    ``ops.concat_and_cache_mla``, and every entry is then read back. The
    returned tensors therefore carry the rounding introduced by that cache
    format.

    Args:
        kv_c: Latent values, indexed by token along dim 0; the last dim is the
            latent rank.
        k_pe: Rope values, indexed by token along dim 0; the last dim is the
            rope size.
        block_size: Number of tokens per block of the scratch cache.
        scale: KV-cache scale tensor forwarded to ``concat_and_cache_mla``.

    Returns:
        Dequantized ``(kv_c, k_pe)`` with the same shapes and dtypes as the
        inputs. Empty inputs are returned as clones.
    """
    if kv_c.numel() == 0:
        return kv_c.clone(), k_pe.clone()

    num_tokens = kv_c.shape[0]
    kv_lora_rank = kv_c.shape[-1]
    rope_dim = k_pe.shape[-1]
    num_blocks = max(1, cdiv(num_tokens, block_size))
    # FP8 latent bytes + four float32 tile scales + two bytes per rope value.
    entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim

    tmp_cache = torch.zeros(
        num_blocks, block_size, entry_size, dtype=torch.uint8, device=kv_c.device
    )
    # Token i is written to slot i, so the scratch cache is filled contiguously.
    slot_mapping = torch.arange(num_tokens, dtype=torch.long, device=kv_c.device)

    ops.concat_and_cache_mla(
        kv_c, k_pe, tmp_cache, slot_mapping, kv_cache_dtype="fp8_ds_mla", scale=scale
    )

    # Since slot == token index, row i of the flattened cache holds token i.
    # Indexing it directly avoids a device->host sync per token.
    entries = tmp_cache.view(-1, entry_size)

    dequant_kv_c = torch.empty_like(kv_c)
    dequant_k_pe = torch.empty_like(k_pe)
    for token_idx in range(num_tokens):
        latent, rope_vals = _dequantize_fp8_ds_mla_entry(
            entries[token_idx], kv_lora_rank, rope_dim, kv_c.dtype
        )
        dequant_kv_c[token_idx] = latent
        dequant_k_pe[token_idx] = rope_vals

    return dequant_kv_c, dequant_k_pe


@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
def test_sparse_backend_decode_correctness(dist_init, batch_name, kv_cache_dtype):
    """Compare FlashMLASparse backend output against a dense SDPA reference.

    The top-k indices handed to the backend select token positions ``0..p`` for
    a query at position ``p`` (padded with -1 up to ``topk_tokens``). For
    sequences that fit in the top-k budget, the sparse result should therefore
    match plain causal attention, up to quantization tolerance.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA is required for sparse MLA decode test")

    # Skip before building configs and tensors if the kernel cannot run here.
    ok, reason = flashmla.is_flashmla_sparse_supported()
    if not ok:
        pytest.skip(reason)

    device = torch.device("cuda")
    dtype = torch.bfloat16

    batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name]

    # MLA hyper-parameters shared by the reference and the backend.
    num_heads = 128
    kv_lora_rank = 512
    qk_nope_head_dim = 128
    qk_rope_head_dim = 64
    v_head_dim = 128
    head_size = kv_lora_rank + qk_rope_head_dim
    topk_tokens = 2048

    max_seqlen = max(batch_spec.seq_lens)
    total_cache_tokens = sum(batch_spec.seq_lens)
    block_size = 64

    vllm_config = create_vllm_config(
        model_name="deepseek-ai/DeepSeek-V2-Lite-Chat",
        max_model_len=max_seqlen,
        num_gpu_blocks=max(2048, cdiv(total_cache_tokens, block_size) + 1),
        block_size=block_size,
        hf_config_override={
            "index_topk": topk_tokens,
            "attn_module_list_cfg": [{"topk_tokens": topk_tokens}],
        },
    )
    # Override the model config so it reports the dimensions used below.
    model_config = vllm_config.model_config
    model_config.hf_text_config = SimpleNamespace(
        q_lora_rank=None,
        kv_lora_rank=kv_lora_rank,
        qk_nope_head_dim=qk_nope_head_dim,
        qk_rope_head_dim=qk_rope_head_dim,
        v_head_dim=v_head_dim,
        model_type="deepseek_v2",
    )
    model_config.dtype = dtype
    model_config.get_num_attention_heads = MethodType(
        lambda self, parallel_config: num_heads, model_config
    )
    model_config.get_num_kv_heads = MethodType(
        lambda self, parallel_config: 1, model_config
    )
    model_config.get_head_size = MethodType(lambda self: head_size, model_config)
    model_config.get_sliding_window = MethodType(lambda self: None, model_config)

    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)

    torch.manual_seed(0)

    scale = 1.0 / math.sqrt(head_size)

    # Shared MLA projection weights to keep reference and backend in sync
    W_UK = torch.randn(
        kv_lora_rank, num_heads, qk_nope_head_dim, dtype=dtype, device=device
    )
    W_UV = torch.randn(kv_lora_rank, num_heads, v_head_dim, dtype=dtype, device=device)

    # Build the synthetic workload (decode and/or prefill requests, depending on
    # the batch spec) together with a dense SDPA reference for each request.
    seq_lens = batch_spec.seq_lens
    query_lens = batch_spec.query_lens

    all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
    kv_c_contexts, k_pe_contexts = [], []
    reference_outputs = []

    kv_cache_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    for i in range(batch_spec.batch_size):
        s_len = seq_lens[i]
        q_len = query_lens[i]
        ctx_len = s_len - q_len

        q_c = torch.rand(
            q_len,
            num_heads,
            qk_nope_head_dim + qk_rope_head_dim,
            dtype=dtype,
            device=device,
        )
        kv_c_full = torch.rand(s_len, kv_lora_rank, dtype=dtype, device=device)
        k_pe_full = torch.rand(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device)

        # Round-trip through the fp8_ds_mla layout so the reference and the
        # backend both see values that survive fp8 quantization.
        kv_c_full, k_pe_full = _quantize_dequantize_fp8_ds_mla(
            kv_c_full,
            k_pe_full.squeeze(1),
            block_size=vllm_config.cache_config.block_size,
            scale=kv_cache_scale,
        )

        # Absorb W_UK into the query so attention runs in the latent space
        # against a single shared K/V (MQA form).
        q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
        ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, W_UK)
        q_mqa = torch.cat([ql_nope, q_pe], dim=-1)

        k_mqa = torch.cat([kv_c_full, k_pe_full], dim=-1)
        k_mqa = k_mqa.unsqueeze(1).expand(-1, num_heads, -1)
        v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_heads, -1)

        # New tokens see the whole context plus causally earlier new tokens.
        attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
        causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
        attn_mask[:, ctx_len:] = causal_mask

        q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
        k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
        v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)

        sdpa_out = torch.nn.functional.scaled_dot_product_attention(
            q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale
        )
        sdpa_out = sdpa_out.transpose(1, 2).squeeze(0)

        # Project the latent attention output back up through W_UV.
        sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV)
        reference_outputs.append(sdpa_out.flatten(start_dim=-2))

        all_q_vllm.append(q_c)
        all_kv_c_vllm.append(kv_c_full[ctx_len:])
        all_k_pe_vllm.append(k_pe_full[ctx_len:])
        # NOTE(review): the prepopulated context also includes the first new
        # token (ctx_len + 1), presumably so pure-prefill requests have a
        # non-empty context; confirm this is intended.
        kv_c_contexts.append(kv_c_full[: ctx_len + 1])
        k_pe_contexts.append(k_pe_full[: ctx_len + 1])

    query_vllm = torch.cat(all_q_vllm, dim=0)
    kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
    k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
    sdpa_reference = torch.cat(reference_outputs, dim=0)

    vllm_config.cache_config.cache_dtype = kv_cache_dtype
    vllm_config.model_config.hf_config.index_topk = topk_tokens

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        vllm_config.cache_config.block_size,
        device,
        arange_block_indices=True,
    )

    kv_cache = create_and_prepopulate_kv_cache(
        kv_c_contexts=kv_c_contexts,
        k_pe_contexts=k_pe_contexts,
        block_size=vllm_config.cache_config.block_size,
        head_size=head_size,
        dtype=dtype,
        device=device,
        num_blocks=vllm_config.cache_config.num_gpu_blocks,
        common_attn_metadata=common_attn_metadata,
        randomize_blocks=False,
        kv_cache_dtype=vllm_config.cache_config.cache_dtype,
        scale=kv_cache_scale,
    )

    builder_cls = FlashMLASparseBackend.get_builder_cls()
    builder = builder_cls(kv_cache_spec, ["placeholder"], vllm_config, device)
    metadata = builder.build(
        common_prefix_len=0, common_attn_metadata=common_attn_metadata
    )

    # Absolute position of every query token within its request: the
    # request's context length plus the token's offset among the new tokens.
    starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
    seg_lengths = np.diff(starts)
    positions = np.arange(starts[-1], dtype=np.int32) - np.repeat(
        starts[:-1], seg_lengths
    )
    seq_lengths = np.asarray(common_attn_metadata.seq_lens_cpu, dtype=np.int32)
    prefix_lengths = seq_lengths - seg_lengths
    positions += np.repeat(prefix_lengths, seg_lengths)

    # Top-k indices: a query at position p selects positions 0..p; the
    # remaining slots are padded with -1.
    pos_gpu = torch.as_tensor(positions, device=device, dtype=torch.int32)
    topk = metadata.topk_tokens
    debug_indices = torch.arange(topk, device=device, dtype=torch.int32).unsqueeze(0)
    token_positions = pos_gpu.unsqueeze(1)
    causal_mask = debug_indices <= token_positions
    debug_indices = torch.where(
        causal_mask, debug_indices, torch.full_like(debug_indices, -1)
    )

    # FlashMLASparseImpl now reads top-k indices from the indexer-provided
    # buffer, so emulate that contract with a simple namespace mock.
    debug_indices = debug_indices.expand(metadata.num_actual_tokens, -1).clone()
    mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices)

    kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
    kv_b_proj_weight = kv_b_proj_weight.view(
        kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)
    )

    mock_kv_b_proj = ColumnParallelLinear(
        input_size=kv_lora_rank,
        output_size=num_heads * (qk_nope_head_dim + v_head_dim),
        bias=False,
    ).to(device=device, dtype=dtype)
    mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous())

    impl_cls = FlashMLASparseBackend.get_impl_cls()
    impl = impl_cls(
        num_heads=num_heads,
        head_size=head_size,
        scale=scale,
        num_kv_heads=1,
        alibi_slopes=None,
        sliding_window=None,
        kv_cache_dtype=vllm_config.cache_config.cache_dtype,
        logits_soft_cap=None,
        attn_type="decoder",
        kv_sharing_target_layer_name=None,
        q_lora_rank=None,
        kv_lora_rank=kv_lora_rank,
        qk_nope_head_dim=qk_nope_head_dim,
        qk_rope_head_dim=qk_rope_head_dim,
        qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
        v_head_dim=v_head_dim,
        kv_b_proj=mock_kv_b_proj,
        indexer=mock_indexer,
    )

    impl.process_weights_after_loading(dtype)

    layer = MockAttentionLayer(device)
    out_buffer = torch.empty(
        metadata.num_actual_tokens, num_heads * v_head_dim, dtype=dtype, device=device
    )

    with torch.inference_mode():
        backend_output = impl.forward(
            layer,
            query_vllm,
            kv_c_vllm,
            k_pe_vllm,
            kv_cache,
            metadata,
            output=out_buffer,
        )

    assert backend_output.shape == sdpa_reference.shape
    assert backend_output.dtype == sdpa_reference.dtype
    assert torch.isfinite(backend_output).all()

    torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.5, atol=0.5)


# (seq_lens, max_buf, start, expected) cases for split_prefill_chunks.
_SPLIT_PREFILL_CHUNK_CASES = [
    # Token totals per chunk never exceed max_buf.
    (torch.tensor([2, 3, 4, 2]), 5, 0, [(0, 2), (2, 3), (3, 4)]),
    # Splitting may begin at a non-zero request index.
    (torch.tensor([2, 3, 4, 2]), 5, 1, [(1, 2), (2, 3), (3, 4)]),
    # A request that exactly fills the buffer closes its chunk, because
    # adding the next request would overflow it.
    (torch.tensor([5, 5, 5]), 5, 0, [(0, 1), (1, 2), (2, 3)]),
    # Every request fits into one chunk.
    (torch.tensor([1, 1, 1]), 10, 0, [(0, 3)]),
    # Generous buffer combined with a non-zero start.
    (torch.tensor([4, 4, 4]), 100, 1, [(1, 3)]),
]


@pytest.mark.parametrize(
    ("seq_lens", "max_buf", "start", "expected"), _SPLIT_PREFILL_CHUNK_CASES
)
def test_split_prefill_chunks(seq_lens, max_buf, start, expected):
    """split_prefill_chunks yields the expected half-open request ranges."""
    assert split_prefill_chunks(seq_lens, max_buf, start) == expected