test_sparse_mla_backends.py 21.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for the FlashMLA sparse backend utilities."""

import math
from types import MethodType, SimpleNamespace

import numpy as np
import pytest
import torch

from tests.v1.attention.test_mla_backends import (
13
14
15
16
17
18
19
20
21
22
    BATCH_SPECS,
    BatchSpec,
    MockAttentionLayer,
    create_and_prepopulate_kv_cache,
)
from tests.v1.attention.utils import (
    create_common_attn_metadata,
    create_standard_kv_cache_spec,
    create_vllm_config,
)
23
from vllm import _custom_ops as ops
24
from vllm.config import set_current_vllm_config
25
from vllm.model_executor.layers.linear import ColumnParallelLinear
26
from vllm.platforms import current_platform
27
from vllm.utils.math_utils import cdiv
28
29
30
31
32
from vllm.v1.attention.backends.mla.flashmla_sparse import (
    FlashMLASparseBackend,
    triton_convert_req_index_to_global_index,
)
from vllm.v1.attention.backends.utils import split_prefill_chunks
33
from vllm.v1.attention.ops import flashmla
34
35
36
37
38
39
40
41
42
43
44
45

# Subset of the shared MLA batch specs that the sparse backend test exercises.
SPARSE_BACKEND_BATCH_SPECS = {
    name: BATCH_SPECS[name]
    for name in [
        "mixed_small",
        "mixed_medium",
        "small_prefill",
        "medium_prefill",
        "single_prefill",
    ]
}

# Two extra specs with long queries: one with cached context in front of the
# query (seq_len > query_len) and one where the query covers the whole sequence.
SPARSE_BACKEND_BATCH_SPECS["large_q_prefill"] = BatchSpec(
    seq_lens=[1024] * 2, query_lens=[256] * 2
)
SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec(
    seq_lens=[256] * 2, query_lens=[256] * 2
)
52
53


Lucas Wilkinson's avatar
Lucas Wilkinson committed
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def _float_to_e8m0_truncate(f: float) -> float:
    """Simulate SM100's float -> e8m0 -> bf16 scale conversion.

    e8m0 format only stores the exponent (power of 2).
    cudaRoundZero truncates toward zero, meaning we round down to the
    nearest power of 2.
    """
    if f <= 0:
        return 0.0
    # e8m0 = floor(log2(f)), then 2^(e8m0)
    # This is equivalent to truncating to the nearest power of 2 below f
    exp = math.floor(math.log2(f))
    return 2.0**exp


69
def _dequantize_fp8_ds_mla_entry(
    cache_slice: torch.Tensor,
    kv_lora_rank: int,
    rope_dim: int,
    dtype: torch.dtype,
    simulate_sm100_e8m0_scales: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Dequantize a single fp8_ds_mla cache entry back to latent + rope.

    Args:
        cache_slice: Raw bytes of one cache entry (one token).
        kv_lora_rank: Number of latent elements stored as FP8 at the start of
            the entry.
        rope_dim: Number of rope elements stored after the scales.
        dtype: Dtype the latent is returned in and the rope payload is
            reinterpreted as.
        simulate_sm100_e8m0_scales: If True, simulate the SM100 kernel's
            float -> e8m0 -> bf16 scale conversion path.

    Returns:
        ``(latent, rope)``: the dequantized latent of length ``kv_lora_rank``
        in ``dtype`` and a copy of the ``rope_dim`` rope values.
    """
    # The layout decoded here is fixed at four tiles of 128 latent elements.
    # NOTE(review): that only covers the whole latent when kv_lora_rank == 512.
    num_tiles = 4
    tile_size = 128

    # The first kv_lora_rank bytes store FP8 latent values with one scale per
    # 128 element tile written as float32 right after the latent payload.
    scale_start = kv_lora_rank // 4  # byte offset kv_lora_rank in float32 units
    scales = cache_slice.view(torch.float32)[scale_start : scale_start + num_tiles]
    # Pull all tile scales to the host in one transfer.
    scale_values = scales.tolist()

    latent = torch.empty(kv_lora_rank, dtype=torch.float16, device=cache_slice.device)
    for tile_idx, scale_val in enumerate(scale_values):
        tile_start = tile_idx * tile_size
        tile_end = tile_start + tile_size
        if simulate_sm100_e8m0_scales:
            # Simulate the lossy float -> e8m0 -> bf16 conversion
            scale_val = _float_to_e8m0_truncate(scale_val)
        ops.convert_fp8(
            latent[tile_start:tile_end],
            cache_slice[tile_start:tile_end],
            scale_val,
            kv_dtype="fp8",
        )
    latent = latent.to(dtype)

    # Offset of the rope payload in units of `dtype` elements. For a 2-byte
    # dtype this is byte offset kv_lora_rank + 16, i.e. right after the four
    # float32 scales. NOTE(review): assumes `dtype` is 2 bytes wide.
    rope_offset = kv_lora_rank // 2 + 8
    rope_vals = cache_slice.view(dtype)[rope_offset : rope_offset + rope_dim]
    return latent, rope_vals.clone()


def _quantize_dequantize_fp8_ds_mla(
    kv_c: torch.Tensor,
    k_pe: torch.Tensor,
    block_size: int,
    scale: torch.Tensor,
    simulate_sm100_e8m0_scales: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Round-trip kv_c/k_pe through the fp8_ds_mla cache layout.

    The inputs are written into a temporary cache with
    ``ops.concat_and_cache_mla`` and then read back one token at a time.

    Args:
        kv_c: Latent values; the first dimension indexes tokens and the last
            dimension is the latent rank.
        k_pe: Rope values; the last dimension is the rope width.
        block_size: Number of token slots per block of the temporary cache.
        scale: Scale tensor forwarded to ``ops.concat_and_cache_mla``.
        simulate_sm100_e8m0_scales: If True, simulate the SM100 kernel's
            float -> e8m0 -> bf16 scale conversion in dequantization.

    Returns:
        ``(kv_c, k_pe)`` after the round trip, with the same shapes and dtypes
        as the inputs. Empty inputs are returned as clones.
    """
    if kv_c.numel() == 0:
        return kv_c.clone(), k_pe.clone()

    kv_lora_rank = kv_c.shape[-1]
    rope_dim = k_pe.shape[-1]
    num_tokens = kv_c.shape[0]
    num_blocks = max(1, math.ceil(num_tokens / block_size))
    # Bytes per entry: FP8 latent + four float32 scales + 2-byte rope values.
    entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim

    tmp_cache = torch.zeros(
        num_blocks, block_size, entry_size, dtype=torch.uint8, device=kv_c.device
    )
    # Token i is stored in slot i.
    slot_mapping = torch.arange(num_tokens, dtype=torch.long, device=kv_c.device)

    ops.concat_and_cache_mla(
        kv_c, k_pe, tmp_cache, slot_mapping, kv_cache_dtype="fp8_ds_mla", scale=scale
    )

    dequant_kv_c = torch.empty_like(kv_c)
    dequant_k_pe = torch.empty_like(k_pe)

    # Copy the slot mapping to the host once instead of calling .item() (and
    # synchronizing with the device) for every token.
    for token_idx, slot in enumerate(slot_mapping.tolist()):
        block_idx, block_offset = divmod(slot, block_size)
        cache_slice = tmp_cache[block_idx, block_offset]
        latent, rope_vals = _dequantize_fp8_ds_mla_entry(
            cache_slice,
            kv_lora_rank,
            rope_dim,
            kv_c.dtype,
            simulate_sm100_e8m0_scales=simulate_sm100_e8m0_scales,
        )
        dequant_kv_c[token_idx] = latent
        dequant_k_pe[token_idx] = rope_vals

    return dequant_kv_c, dequant_k_pe


@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS))
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
@pytest.mark.skipif(
    # Check availability first: get_device_capability() raises on hosts without
    # CUDA, which would break test collection instead of skipping.
    not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
    reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
)
def test_sparse_backend_decode_correctness(
    default_vllm_config,
    dist_init,
    batch_name,
    kv_cache_dtype,
    tensor_parallel_size,
    workspace_init,
):
    """Compare FlashMLASparseBackend output against an SDPA reference.

    For every request a dense causal reference is computed with
    ``scaled_dot_product_attention`` on KV data that has been round-tripped
    through the fp8_ds_mla cache layout. The backend is then run with top-k
    indices that select every causally visible position, and both outputs are
    compared with a tolerance that depends on ``kv_cache_dtype``.
    """
    if current_platform.is_rocm():
        pytest.skip("ROCm does not support fp8_ds_mla data type for kv cache.")

    # Bail out before any expensive setup when the sparse kernels are missing.
    ok, reason = flashmla.is_flashmla_sparse_supported()
    if not ok:
        pytest.skip(reason)

    device = torch.device("cuda")
    dtype = torch.bfloat16

    batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name]

    # Model hyper-parameters (kept intentionally small for the unit test)
    total_num_heads = 128
    # Compute per-rank heads for simulated TP
    num_heads = max(1, total_num_heads // tensor_parallel_size)

    kv_lora_rank = 512
    qk_nope_head_dim = 128
    qk_rope_head_dim = 64
    v_head_dim = 128
    head_size = kv_lora_rank + qk_rope_head_dim
    topk_tokens = 2048

    max_seqlen = max(batch_spec.seq_lens)
    total_cache_tokens = sum(batch_spec.seq_lens)
    block_size = 64

    # Note: We use TP=1 to avoid multi-GPU requirements in CI.
    # The test simulates head partitioning via mocked methods below.
    vllm_config = create_vllm_config(
        model_name="deepseek-ai/DeepSeek-V2-Lite-Chat",
        tensor_parallel_size=1,
        max_model_len=max_seqlen,
        num_gpu_blocks=max(2048, cdiv(total_cache_tokens, block_size) + 1),
        block_size=block_size,
        hf_config_override={
            "index_topk": topk_tokens,
            "attn_module_list_cfg": [{"topk_tokens": topk_tokens}],
        },
    )
    model_config = vllm_config.model_config
    model_config.hf_text_config = SimpleNamespace(
        q_lora_rank=None,
        kv_lora_rank=kv_lora_rank,
        qk_nope_head_dim=qk_nope_head_dim,
        qk_rope_head_dim=qk_rope_head_dim,
        v_head_dim=v_head_dim,
        model_type="deepseek_v2",
    )
    model_config.dtype = dtype
    # Override the config accessors so the backend sees the per-rank head
    # count and the MLA head size chosen above.
    model_config.get_num_attention_heads = MethodType(
        lambda self, parallel_config: num_heads,
        model_config,
    )
    model_config.get_num_kv_heads = MethodType(
        lambda self, parallel_config: 1, model_config
    )
    model_config.get_head_size = MethodType(lambda self: head_size, model_config)
    model_config.get_sliding_window = MethodType(lambda self: None, model_config)

    kv_cache_spec = create_standard_kv_cache_spec(vllm_config)

    torch.manual_seed(0)

    scale = 1.0 / math.sqrt(head_size)

    # Shared MLA projection weights to keep reference and backend in sync
    W_UK = torch.rand(
        kv_lora_rank, num_heads, qk_nope_head_dim, dtype=dtype, device=device
    )
    W_UV = torch.rand(kv_lora_rank, num_heads, v_head_dim, dtype=dtype, device=device)

    # Build the synthetic workload: each request has `s_len - q_len` tokens of
    # cached context followed by `q_len` new query tokens.
    seq_lens = batch_spec.seq_lens
    query_lens = batch_spec.query_lens

    all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], []
    kv_c_contexts, k_pe_contexts = [], []
    reference_outputs = []

    kv_cache_scale = torch.tensor(1.0, dtype=torch.float32, device=device)

    # SM100 (Blackwell) uses float -> e8m0 -> bf16 scale conversion
    # which truncates scales to powers of 2. Simulate this in reference.
    # The device capability does not change per request, so query it once.
    is_sm100 = torch.cuda.get_device_capability()[0] >= 10

    for i in range(batch_spec.batch_size):
        s_len = seq_lens[i]
        q_len = query_lens[i]
        ctx_len = s_len - q_len

        q_c = torch.rand(
            q_len,
            num_heads,
            qk_nope_head_dim + qk_rope_head_dim,
            dtype=dtype,
            device=device,
        )
        kv_c_full = torch.rand(s_len, kv_lora_rank, dtype=dtype, device=device)
        k_pe_full = torch.rand(s_len, 1, qk_rope_head_dim, dtype=dtype, device=device)

        # Round-trip the KV data through the fp8_ds_mla layout (for both
        # kv_cache_dtype settings) before it feeds the reference and the cache.
        kv_c_full, k_pe_full = _quantize_dequantize_fp8_ds_mla(
            kv_c_full,
            k_pe_full.squeeze(1),
            block_size=vllm_config.cache_config.block_size,
            scale=kv_cache_scale,
            simulate_sm100_e8m0_scales=is_sm100,
        )

        # Absorb W_UK into the query so attention runs in the latent space.
        q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
        ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, W_UK)
        q_mqa = torch.cat([ql_nope, q_pe], dim=-1)

        # A single KV head is shared (expanded) across all query heads.
        k_mqa = torch.cat([kv_c_full, k_pe_full], dim=-1)
        k_mqa = k_mqa.unsqueeze(1).expand(-1, num_heads, -1)
        v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_heads, -1)

        # Every query token sees the whole context; the new tokens are causal
        # among themselves.
        attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device)
        causal_mask = torch.tril(torch.ones(q_len, q_len, device=device))
        attn_mask[:, ctx_len:] = causal_mask

        q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2)
        k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2)
        v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2)

        sdpa_out = torch.nn.functional.scaled_dot_product_attention(
            q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale
        )
        sdpa_out = sdpa_out.transpose(1, 2).squeeze(0)

        # Project the latent-space output back to per-head values with W_UV.
        sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV)
        reference_outputs.append(sdpa_out.flatten(start_dim=-2))

        all_q_vllm.append(q_c)
        all_kv_c_vllm.append(kv_c_full[ctx_len:])
        all_k_pe_vllm.append(k_pe_full[ctx_len:])
        # NOTE(review): the context slices hold ctx_len + 1 tokens, presumably
        # what create_and_prepopulate_kv_cache expects -- confirm.
        kv_c_contexts.append(kv_c_full[: ctx_len + 1])
        k_pe_contexts.append(k_pe_full[: ctx_len + 1])

    query_vllm = torch.cat(all_q_vllm, dim=0)
    kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0)
    k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0)
    sdpa_reference = torch.cat(reference_outputs, dim=0)

    vllm_config.cache_config.cache_dtype = kv_cache_dtype
    vllm_config.model_config.hf_config.index_topk = topk_tokens

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        vllm_config.cache_config.block_size,
        device,
        arange_block_indices=True,
    )

    kv_cache = create_and_prepopulate_kv_cache(
        kv_c_contexts=kv_c_contexts,
        k_pe_contexts=k_pe_contexts,
        block_size=vllm_config.cache_config.block_size,
        head_size=head_size,
        dtype=dtype,
        device=device,
        num_blocks=vllm_config.cache_config.num_gpu_blocks,
        common_attn_metadata=common_attn_metadata,
        randomize_blocks=False,
        kv_cache_dtype=vllm_config.cache_config.cache_dtype,
        scale=kv_cache_scale,
    )

    builder_cls = FlashMLASparseBackend.get_builder_cls()
    builder = builder_cls(kv_cache_spec, ["placeholder"], vllm_config, device)
    metadata = builder.build(
        common_prefix_len=0, common_attn_metadata=common_attn_metadata
    )

    # Position of every query token within its own sequence: the offset inside
    # the request's query segment plus the length of its cached prefix.
    starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
    seg_lengths = np.diff(starts)
    positions = np.arange(starts[-1], dtype=np.int32) - np.repeat(
        starts[:-1], seg_lengths
    )
    seq_lengths = np.asarray(common_attn_metadata.seq_lens.cpu(), dtype=np.int32)
    prefix_lengths = seq_lengths - seg_lengths
    positions += np.repeat(prefix_lengths, seg_lengths)

    # Top-k candidates are simply 0..topk-1; entries beyond a token's own
    # position are masked with -1 so the selection is exactly causal.
    pos_gpu = torch.as_tensor(positions, device=device, dtype=torch.int32)
    topk = metadata.topk_tokens
    debug_indices = torch.arange(topk, device=device, dtype=torch.int32).unsqueeze(0)
    token_positions = pos_gpu.unsqueeze(1)
    visible = debug_indices <= token_positions
    debug_indices = torch.where(
        visible, debug_indices, torch.full_like(debug_indices, -1)
    )

    # FlashMLASparseImpl now reads top-k indices from the indexer-provided
    # buffer, so emulate that contract with a simple namespace mock.
    debug_indices = debug_indices.expand(metadata.num_actual_tokens, -1).clone()
    mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices)

    # Pack W_UK and W_UV into one kv_b_proj weight.
    kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1)
    kv_b_proj_weight = kv_b_proj_weight.view(
        kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)
    )

    mock_kv_b_proj = ColumnParallelLinear(
        input_size=kv_lora_rank,
        output_size=num_heads * (qk_nope_head_dim + v_head_dim),
        bias=False,
    ).to(device=device, dtype=dtype)
    mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous())

    impl_cls = FlashMLASparseBackend.get_impl_cls()
    with set_current_vllm_config(vllm_config):
        impl = impl_cls(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=1,
            alibi_slopes=None,
            sliding_window=None,
            kv_cache_dtype=vllm_config.cache_config.cache_dtype,
            logits_soft_cap=None,
            attn_type="decoder",
            kv_sharing_target_layer_name=None,
            q_lora_rank=None,
            kv_lora_rank=kv_lora_rank,
            qk_nope_head_dim=qk_nope_head_dim,
            qk_rope_head_dim=qk_rope_head_dim,
            qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
            v_head_dim=v_head_dim,
            kv_b_proj=mock_kv_b_proj,
            indexer=mock_indexer,
        )

        impl.process_weights_after_loading(dtype)

    layer = MockAttentionLayer(device)
    out_buffer = torch.empty(
        metadata.num_actual_tokens, num_heads * v_head_dim, dtype=dtype, device=device
    )

    with torch.inference_mode():
        backend_output = impl.forward(
            layer,
            query_vllm,
            kv_c_vllm,
            k_pe_vllm,
            kv_cache,
            metadata,
            output=out_buffer,
        )

    assert backend_output.shape == sdpa_reference.shape
    assert backend_output.dtype == sdpa_reference.dtype
    assert torch.isfinite(backend_output).all()

    # FP8 quantization introduces some error, but should be within reasonable bounds
    # BF16 (auto) should be very accurate, FP8 allows slightly more tolerance
    if kv_cache_dtype == "fp8_ds_mla":
        torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.05, atol=0.05)
    else:
        torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.01, atol=0.01)
438
439


440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
def _triton_convert_reference_impl(
    req_ids: torch.Tensor,
    block_table: torch.Tensor,
    token_indices: torch.Tensor,
    block_size: int,
    num_topk_tokens: int,
    HAS_PREFILL_WORKSPACE: bool = False,
    prefill_workspace_request_ids: torch.Tensor | None = None,
    prefill_workspace_starts: torch.Tensor | None = None,
) -> torch.Tensor:
    """Reference implementation for triton_convert_req_index_to_global_index."""
    num_tokens = req_ids.shape[0]
    max_blocks_per_req = block_table.shape[1]
    result = torch.empty(
        num_tokens, num_topk_tokens, dtype=torch.int32, device=req_ids.device
    )

    for token_id in range(num_tokens):
        req_id = req_ids[token_id].item()

        # Determine if this token uses workspace or paged cache
        use_prefill_workspace = False
        workspace_start = 0
        if HAS_PREFILL_WORKSPACE and prefill_workspace_request_ids is not None:
            assert prefill_workspace_starts is not None
            prefill_req_id = prefill_workspace_request_ids[token_id].item()
            if prefill_req_id >= 0:
                use_prefill_workspace = True
                workspace_start = prefill_workspace_starts[prefill_req_id].item()

        for idx_id in range(num_topk_tokens):
            token_idx = token_indices[token_id, idx_id].item()

            if token_idx == -1:
                result[token_id, idx_id] = -1
            elif use_prefill_workspace:
                # Prefill + using prefill workspace: map to workspace offset
                result[token_id, idx_id] = workspace_start + token_idx
            else:
                # Decode: map to paged cache
                block_id = token_idx // block_size
                if block_id >= max_blocks_per_req:
                    result[token_id, idx_id] = -1
                else:
                    block_num = block_table[req_id, block_id].item()
                    offset = token_idx % block_size
                    result[token_id, idx_id] = block_num * block_size + offset

    return result


@pytest.mark.parametrize("block_size", [16, 64, 128])
@pytest.mark.parametrize("num_topk_tokens", [128, 256, 512])
@pytest.mark.skipif(
    # Check availability first: get_device_capability() raises on hosts without
    # CUDA, which would break test collection instead of skipping.
    not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
    reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
)
def test_triton_convert_req_index_to_global_index_decode_only(
    block_size, num_topk_tokens
):
    """Check the Triton index conversion against the Python reference.

    All tokens go through the paged-cache path (no prefill workspace). The
    inputs include ``-1`` entries and indices one block past the end of the
    block table.
    """
    # Seed the generator so a failure can be reproduced with the same inputs.
    torch.manual_seed(0)

    device = torch.device("cuda")
    num_tokens = 8
    num_requests = 4
    max_blocks_per_req = 10

    req_id = torch.randint(
        0, num_requests, (num_tokens,), dtype=torch.int32, device=device
    )
    block_table = torch.randint(
        0, 100, (num_requests, max_blocks_per_req), dtype=torch.int32, device=device
    )

    token_indices = torch.randint(
        0,
        block_size * max_blocks_per_req,
        (num_tokens, num_topk_tokens),
        dtype=torch.int32,
        device=device,
    )

    # Set some to -1 to test masking
    token_indices[0, :10] = -1
    token_indices[3, 50:60] = -1

    # Set some to out of bounds
    token_indices[2, 100:110] = max_blocks_per_req * block_size
    # NOTE(review): for num_topk_tokens=128 the slice below is empty, so only
    # row 2 carries out-of-bounds entries in that parametrization.
    token_indices[6, 150:160] = max_blocks_per_req * block_size

    result = triton_convert_req_index_to_global_index(
        req_id,
        block_table,
        token_indices,
        BLOCK_SIZE=block_size,
        NUM_TOPK_TOKENS=num_topk_tokens,
    )

    reference_result = _triton_convert_reference_impl(
        req_id,
        block_table,
        token_indices,
        block_size,
        num_topk_tokens,
    )

    # Integer indices must match exactly.
    torch.testing.assert_close(result, reference_result, rtol=0, atol=0)


@pytest.mark.parametrize("block_size", [16])
@pytest.mark.skipif(
    # Check availability first: get_device_capability() raises on hosts without
    # CUDA, which would break test collection instead of skipping.
    not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0),
    reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
)
def test_triton_convert_req_index_to_global_index_with_prefill_workspace(block_size):
    """Check the Triton index conversion with a mixed decode/prefill batch.

    Decode tokens are resolved through the block table, while prefill tokens
    are mapped to offsets in the prefill workspace. The result is compared
    against the Python reference.
    """
    # Seed the generator so a failure can be reproduced with the same inputs.
    torch.manual_seed(0)

    device = torch.device("cuda")
    num_requests = 4
    max_blocks_per_req = 8
    num_topk_tokens = 128

    # First 6 tokens are decode (reqs 0, 1), last 6 are prefill (reqs 2, 3)
    req_id = torch.tensor(
        [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], dtype=torch.int32, device=device
    )
    # -1 marks a decode token; otherwise the index into prefill_workspace_starts.
    prefill_workspace_request_ids = torch.tensor(
        [-1, -1, -1, -1, -1, -1, 0, 0, 0, 1, 1, 1], dtype=torch.int32, device=device
    )

    # Workspace starts for the 2 prefill reqs: req 2 starts at 0, req 3 starts at 100
    prefill_workspace_starts = torch.tensor([0, 100], dtype=torch.int32, device=device)

    block_table = torch.randint(
        0, 50, (num_requests, max_blocks_per_req), dtype=torch.int32, device=device
    )
    token_indices = torch.randint(
        0,
        block_size * max_blocks_per_req,
        (req_id.shape[0], num_topk_tokens),
        dtype=torch.int32,
        device=device,
    )

    # Set some to -1 to test masking
    token_indices[0, :10] = -1
    token_indices[3, 50:60] = -1

    # Set some to out of bounds
    token_indices[2, 100:110] = max_blocks_per_req * block_size
    # NOTE(review): num_topk_tokens is 128, so the slice 150:160 is empty and
    # this assignment is a no-op; no prefill token (row 6) actually receives an
    # out-of-bounds index. Confirm the intended column range before changing it.
    token_indices[6, 150:160] = max_blocks_per_req * block_size

    result = triton_convert_req_index_to_global_index(
        req_id,
        block_table,
        token_indices,
        BLOCK_SIZE=block_size,
        NUM_TOPK_TOKENS=num_topk_tokens,
        HAS_PREFILL_WORKSPACE=True,
        prefill_workspace_request_ids=prefill_workspace_request_ids,
        prefill_workspace_starts=prefill_workspace_starts,
    )

    reference_result = _triton_convert_reference_impl(
        req_id,
        block_table,
        token_indices,
        block_size,
        num_topk_tokens,
        HAS_PREFILL_WORKSPACE=True,
        prefill_workspace_request_ids=prefill_workspace_request_ids,
        prefill_workspace_starts=prefill_workspace_starts,
    )

    # Integer indices must match exactly.
    torch.testing.assert_close(result, reference_result, rtol=0, atol=0)


613
@pytest.mark.parametrize(
    "seq_lens,max_buf,expected",
    [
        # Basic split: totals per chunk ≤ max_buf
        (torch.tensor([2, 3, 4, 2]), 5, [(0, 2), (2, 3), (3, 4)]),
        # Exact fits should split between items when adding the next would overflow
        (torch.tensor([5, 5, 5]), 5, [(0, 1), (1, 2), (2, 3)]),
        # All requests fit in a single chunk
        (torch.tensor([1, 1, 1]), 10, [(0, 3)]),
        # Large buffer
        (torch.tensor([4, 4, 4]), 100, [(0, 3)]),
    ],
)
def test_split_prefill_chunks(seq_lens, max_buf, expected):
    """split_prefill_chunks returns the expected (start, end) request ranges."""
    chunks = split_prefill_chunks(seq_lens, max_buf)
    assert chunks == expected