"vllm/vscode:/vscode.git/clone" did not exist on "4f044b1d67964e53587a4d0c7f00233a04b7be4e"
test_mamba_prefix_cache.py 30.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import multiprocessing as mp
import os
import traceback
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

import datasets
import pytest
import torch

14
from tests.utils import create_new_process_for_each_test
15
16
from vllm import LLM, SamplingParams, TokensPrompt
from vllm.config import CacheConfig
17
from vllm.distributed import cleanup_dist_env_and_memory
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from vllm.model_executor.layers.mamba.mamba_utils import MambaStateCopyFunc
from vllm.sequence import IntermediateTensors
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine.core_client import InprocClient
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import SamplerOutput
from vllm.v1.request import Request
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker import mamba_utils
from vllm.v1.worker.gpu_input_batch import CachedRequestState
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.lora_model_runner_mixin import GPUInputBatch
from vllm.v1.worker.mamba_utils import get_mamba_groups


@dataclass
class StepAction:
    """Expected observations for one engine step of a scripted test case.

    The patched hooks in this module compare what the engine actually does
    in a step against the values stored here.
    """

    # Compared with input_batch.num_computed_tokens_cpu[0] after the
    # (patched) execute_model call returns.
    num_computed_tokens_start: int
    # Compared with the first entry of scheduler_output.num_scheduled_tokens.
    num_scheduled_tokens: int
    # Per-block flags for the request's blocks after allocate_slots:
    # 1 for a non-null block, 0 for a null block.
    kv_cache_block_ids: list[int]  # [] to follow last step
    # Pair of block-table indices used to look up the expected source and
    # destination of the mamba state copy during preprocessing.
    preprocess_copy_idx: tuple[int, int]  # -1, -1 for no copy
    # Same as above, for the copy issued during postprocessing.
    postprocess_copy_idx: tuple[int, int]  # -1, -1 for no copy


# Number of draft tokens proposed per step; also passed to the engine's
# speculative_config in the test.
num_speculative_tokens = 3

# Number of tokens the fake sampler emits per speculative step; overwritten
# by each test case before it runs.
num_accepted_tokens = 1
# Token ids of the full tokenized article; the fake sampler and the fake
# drafter read their outputs from this list.
prompt_token_ids: list[int] = []
MODEL = "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8"
BLOCK_SIZE = 560
# Passed to the engine via hf_overrides["num_hidden_layers"].
NUM_HIDDEN_LAYERS = 1
# Index of the next entry of `step_actions` to hand out, and the entry that
# applies to the step currently being executed (None once the script is
# exhausted). Both are advanced by the patched InprocClient.get_output.
cur_step_action_idx = 0
cur_step_action: StepAction | None = None
# Scripted expectations for the test case currently running.
step_actions: list[StepAction] = []


def get_fake_sample_fn() -> Callable[..., SamplerOutput]:
    """Build a deterministic stand-in for ``GPUModelRunner._sample``.

    The returned function ignores the logit values and instead reads the
    "sampled" tokens from the module-level ``prompt_token_ids`` list, so the
    generated sequence is fully determined by the tokenized article.

    Returns:
        A function with the ``(self, logits, spec_decode_metadata)``
        signature that returns a ``SamplerOutput``.
    """

    def fake_sample_fn(
        self: GPUModelRunner,
        logits: torch.Tensor | None,
        spec_decode_metadata: SpecDecodeMetadata | None,
    ) -> SamplerOutput:
        assert logits is not None
        # Only the first request in the batch is inspected.
        num_computed_tokens = self.input_batch.num_computed_tokens_cpu_tensor[
            0
        ].item()
        num_prompt_tokens = self.input_batch.num_prompt_tokens[0].item()
        if num_computed_tokens < num_prompt_tokens:
            # Prompt not fully computed yet: the next token is the one that
            # follows the prompt in the article.
            first_token_id_index = num_prompt_tokens
        else:
            first_token_id_index = num_computed_tokens + 1

        if spec_decode_metadata is None:
            # No speculative decoding in this step: emit exactly one token.
            sampled_token_ids = [prompt_token_ids[first_token_id_index]]
        else:
            # Emit up to `num_accepted_tokens` tokens, bounded by the number
            # of logit rows available in this step.
            num_emitted = min(num_accepted_tokens, logits.shape[0])
            sampled_token_ids = prompt_token_ids[
                first_token_id_index : first_token_id_index + num_emitted
            ]

        return SamplerOutput(
            sampled_token_ids=torch.tensor(
                [sampled_token_ids], device="cuda", dtype=torch.int32
            ),
            logprobs_tensors=None,
        )

    return fake_sample_fn


def get_fake_propose_draft_token_ids_fn():
    """Build a deterministic stand-in for
    ``GPUModelRunner.propose_draft_token_ids``.

    The returned function proposes ``num_speculative_tokens`` draft tokens
    read from the module-level ``prompt_token_ids`` list instead of running a
    draft model.
    """

    def fake_propose_draft_token_ids_fn(
        self: GPUModelRunner,
        scheduler_output: SchedulerOutput,
        sampled_token_ids: torch.Tensor | list[list[int]],
        sampling_metadata: SamplingMetadata,
        hidden_states: torch.Tensor,
        sample_hidden_states: torch.Tensor,
        aux_hidden_states: list[torch.Tensor] | None,
        spec_decode_metadata: SpecDecodeMetadata | None,
        common_attn_metadata: CommonAttentionMetadata,
        slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None,
    ) -> torch.Tensor:
        # Only the first request in the batch is inspected.
        num_computed_tokens_cpu_tensor = self.input_batch.num_computed_tokens_cpu_tensor
        num_computed_tokens = num_computed_tokens_cpu_tensor[0].item()
        if (
            self.input_batch.num_tokens_no_spec[0].item()
            <= self.input_batch.num_prompt_tokens[0].item()
        ):
            first_token_id_index = self.input_batch.num_prompt_tokens[0].item()
        else:
            first_token_id_index = (
                num_computed_tokens + 1
            )  # bonus token isn't considered as computed
        # Skip past the tokens accepted in this step.
        first_token_id_index += self.input_batch.num_accepted_tokens_cpu[0].item()
        proposed_draft_token_ids = [
            prompt_token_ids[
                first_token_id_index : first_token_id_index + num_speculative_tokens
            ]
        ]

        next_token_start = first_token_id_index - 1
        next_token_ids = torch.tensor(
            prompt_token_ids[next_token_start : next_token_start + num_accepted_tokens],
            device="cuda",
            dtype=torch.int32,
        )

        valid_sampled_tokens_count = torch.tensor(
            [num_accepted_tokens], device="cuda", dtype=torch.int32
        )

        self._copy_valid_sampled_token_count(next_token_ids, valid_sampled_tokens_count)

        return torch.tensor(proposed_draft_token_ids, device="cuda", dtype=torch.int32)

    return fake_propose_draft_token_ids_fn


def get_fake_step_action_fn(original_step_action_fn: Callable):
    """Wrap ``InprocClient.get_output`` so each call advances the step script.

    Before delegating to ``original_step_action_fn``, the wrapper publishes
    the next entry of ``step_actions`` as ``cur_step_action`` (or ``None``
    once every scripted step has been consumed).
    """

    def fake_get_output(self: InprocClient):
        global cur_step_action_idx, cur_step_action
        script_exhausted = cur_step_action_idx >= len(step_actions)
        if script_exhausted:
            cur_step_action = None
        else:
            cur_step_action = step_actions[cur_step_action_idx]
            cur_step_action_idx += 1
        print(f"cur_step_action: {cur_step_action_idx=} {cur_step_action=}")
        return original_step_action_fn(self)

    return fake_get_output


def get_fake_allocate_slots_fn(original_allocate_slots_fn: Callable):
    """Wrap ``KVCacheManager.allocate_slots`` with a block-layout check.

    The wrapper forwards every argument to ``original_allocate_slots_fn``
    and, while a scripted step is active, asserts that the request's blocks
    in the first single-type manager match the expected null / non-null
    pattern.
    """

    def fake_allocate_slots_fn(
        self: KVCacheManager,
        request: Request,
        num_new_tokens: int,
        num_new_computed_tokens: int = 0,
        new_computed_blocks: KVCacheBlocks | None = None,
        num_lookahead_tokens: int = 0,
        num_external_computed_tokens: int = 0,
        delay_cache_blocks: bool = False,
        num_encoder_tokens: int = 0,
    ):
        result = original_allocate_slots_fn(
            self,
            request,
            num_new_tokens,
            num_new_computed_tokens,
            new_computed_blocks,
            num_lookahead_tokens,
            num_external_computed_tokens,
            delay_cache_blocks,
            num_encoder_tokens,
        )
        if cur_step_action is None:
            # No scripted expectation for this step.
            return result

        first_manager = self.coordinator.single_type_managers[0]
        request_blocks = first_manager.req_to_blocks[request.request_id]
        # 1 marks a non-null block, 0 marks a null block.
        occupancy = [0 if blk.is_null else 1 for blk in request_blocks]
        assert occupancy == cur_step_action.kv_cache_block_ids
        return result

    return fake_allocate_slots_fn


# Snapshots of mamba state captured by the patched execute_model, keyed by a
# token count that is a multiple of BLOCK_SIZE; each value is a pair of
# cloned cache entries for the block that was just completed.
mamba_kv_cache_dict = {}


def get_fake_execute_model_fn(original_execute_model_fn: Callable):
    """Wrap ``GPUModelRunner.execute_model`` with checks and state capture.

    The wrapper:
      * asserts the number of scheduled tokens against the scripted step,
      * snapshots the mamba cache entries of a block into
        ``mamba_kv_cache_dict`` whenever the computed-token count crosses a
        ``BLOCK_SIZE`` boundary, and
      * after delegating to ``original_execute_model_fn``, asserts the number
        of computed tokens against the scripted step.

    State that must survive between calls (the previous computed-token count
    and the prompt length) lives in the closure.
    """
    last_num_computed_tokens = 0
    num_prompt_tokens = None

    def fake_execute_model_fn(
        self: GPUModelRunner,
        scheduler_output: SchedulerOutput,
        intermediate_tensors: IntermediateTensors | None = None,
    ):
        nonlocal last_num_computed_tokens
        nonlocal num_prompt_tokens

        if cur_step_action is not None:
            # Only the first scheduled request is checked.
            num_scheduled_tokens = next(
                iter(scheduler_output.num_scheduled_tokens.values())
            )
            assert num_scheduled_tokens == cur_step_action.num_scheduled_tokens
        mamba_group_ids, mamba_spec = get_mamba_groups(self.kv_cache_config)
        mamba_group_id = mamba_group_ids[0]
        mamba_layer_name = self.kv_cache_config.kv_cache_groups[
            mamba_group_id
        ].layer_names[0]

        if (
            len(scheduler_output.scheduled_new_reqs) > 0
            and scheduler_output.scheduled_new_reqs[0].prompt_token_ids is not None
        ):
            # record number of prompt tokens
            num_prompt_tokens = len(
                scheduler_output.scheduled_new_reqs[0].prompt_token_ids
            )

        if len(scheduler_output.scheduled_cached_reqs.req_ids) > 0:
            num_computed_tokens = (
                scheduler_output.scheduled_cached_reqs.num_computed_tokens[0]
            )
            if (
                self.num_spec_tokens
                and num_prompt_tokens is not None
                and num_computed_tokens > num_prompt_tokens
            ):
                # NOTE (tdoublep) with async scheduling, the scheduler does not have an
                # accurate measure of the number of computed tokens; we need to subtract
                # the number of rejected tokens from the previous timestep.
                num_computed_tokens -= num_speculative_tokens + 1 - num_accepted_tokens
            if (
                num_computed_tokens // BLOCK_SIZE
                > last_num_computed_tokens // BLOCK_SIZE
            ):
                # generated a new aligned block in this step
                block_idx = num_computed_tokens // mamba_spec.block_size - 1
                block_id = (
                    self.input_batch.block_table.block_tables[mamba_group_id]
                    .block_table.cpu[0, block_idx]
                    .item()
                )
                # Block id 0 is skipped; no snapshot is recorded for it.
                if block_id != 0:
                    kv_cache = self.compilation_config.static_forward_context[
                        mamba_layer_name
                    ].kv_cache
                    # Key the snapshot by the block-aligned token count.
                    mamba_kv_cache_dict[
                        num_computed_tokens - num_computed_tokens % BLOCK_SIZE
                    ] = (
                        kv_cache[0][0][block_id].clone(),
                        kv_cache[0][1][block_id].clone(),
                    )

            last_num_computed_tokens = num_computed_tokens
        else:
            # No cached (already running) request in this step: start over.
            last_num_computed_tokens = 0

        ret = original_execute_model_fn(self, scheduler_output, intermediate_tensors)

        if cur_step_action is not None:
            assert (
                cur_step_action.num_computed_tokens_start
                == self.input_batch.num_computed_tokens_cpu[0].item()
            )

        return ret

    return fake_execute_model_fn


def get_fake_process_mamba_fn(
    original_preprocess_mamba_fn: Callable,
    original_post_process_mamba_fn: Callable,
    original_copy_fn: Callable,
):
    """Build wrappers that verify which mamba state copies are issued.

    Returns a ``(preprocess, postprocess, copy)`` triple. The copy wrapper
    records the arguments of the copy call; the pre/post-process wrappers
    reset that record, delegate to the original function, and then compare
    the record with the copy expected by the active scripted step.
    """
    # Arguments of the most recent copy call: (sources, destinations, sizes).
    copy_info: tuple[list[int], list[int], list[int]] | None = None

    def check_copy_info(
        action: tuple[int, int],
        kv_cache_config: KVCacheConfig,
        forward_context: dict[str, Any],
        input_batch: GPUInputBatch,
    ):
        assert copy_info is not None
        src_ptrs, dest_ptrs, sizes = copy_info
        if action == (-1, -1):
            # No copy expected: the recorded lists must all be empty.
            assert len(src_ptrs) == len(dest_ptrs) == len(sizes) == 0
            return

        assert len(src_ptrs) == len(dest_ptrs) == len(sizes) == 2
        src_idx, dest_idx = action
        group_ids, _spec = get_mamba_groups(kv_cache_config)
        group_id = group_ids[0]
        layer_name = kv_cache_config.kv_cache_groups[group_id].layer_names[0]
        temporal_cache = forward_context[layer_name].kv_cache[0][-1]
        block_table_row = input_batch.block_table.block_tables[
            group_id
        ].block_table.cpu[0]
        want_src = temporal_cache[block_table_row[src_idx]].data_ptr()
        want_dest = temporal_cache[block_table_row[dest_idx]].data_ptr()
        # -1 is qwen3-next's temporal. We skip checking conv as it is more complex.
        assert src_ptrs[-1] == want_src
        assert dest_ptrs[-1] == want_dest

    def fake_preprocess_mamba_fn(
        scheduler_output: SchedulerOutput,
        kv_cache_config: KVCacheConfig,
        cache_config: CacheConfig,
        mamba_state_idx: dict[str, int],
        input_batch: GPUInputBatch,
        requests: dict[str, CachedRequestState],
        forward_context: dict[str, Any],
        mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
    ):
        nonlocal copy_info
        # Forget any copy recorded by an earlier phase.
        copy_info = None
        result = original_preprocess_mamba_fn(
            scheduler_output,
            kv_cache_config,
            cache_config,
            mamba_state_idx,
            input_batch,
            requests,
            forward_context,
            mamba_state_copy_funcs,
        )
        if cur_step_action is None:
            return result
        check_copy_info(
            cur_step_action.preprocess_copy_idx,
            kv_cache_config,
            forward_context,
            input_batch,
        )
        return result

    def fake_post_process_mamba_fn(
        scheduler_output: SchedulerOutput,
        kv_cache_config: KVCacheConfig,
        input_batch: GPUInputBatch,
        requests: dict[str, CachedRequestState],
        mamba_state_idx: dict[str, int],
        forward_context: dict[str, Any],
        mamba_state_copy_funcs: tuple[MambaStateCopyFunc, ...],
    ):
        nonlocal copy_info
        # Forget any copy recorded by an earlier phase.
        copy_info = None
        result = original_post_process_mamba_fn(
            scheduler_output,
            kv_cache_config,
            input_batch,
            requests,
            mamba_state_idx,
            forward_context,
            mamba_state_copy_funcs,
        )
        if cur_step_action is None:
            return result
        check_copy_info(
            cur_step_action.postprocess_copy_idx,
            kv_cache_config,
            forward_context,
            input_batch,
        )
        return result

    def fake_copy_fn(
        src_state_list: list[int],
        dest_state_list: list[int],
        num_elements_list: list[int],
    ):
        nonlocal copy_info
        # At most one copy call is expected per pre/post-process phase.
        assert copy_info is None
        copy_info = (src_state_list, dest_state_list, num_elements_list)
        return original_copy_fn(
            src_state_list,
            dest_state_list,
            num_elements_list,
        )

    return fake_preprocess_mamba_fn, fake_post_process_mamba_fn, fake_copy_fn


def run_ref_mamba_state_in_subprocess() -> None:
    """Run the reference-state worker in a freshly spawned process.

    Raises:
        RuntimeError: If the worker does not finish within 600 seconds, or
            if it exits with a non-zero exit code.
    """
    ctx = mp.get_context("spawn")
    proc = ctx.Process(target=_run_ref_mamba_state_worker)
    proc.start()
    proc.join(timeout=600)
    if proc.is_alive():
        # join() returned because of the timeout: the worker is still
        # running. Stop it so it is not left behind, then report the timeout.
        proc.terminate()
        proc.join(timeout=30)
        if proc.is_alive():
            proc.kill()
            proc.join()
        raise RuntimeError("Ref mamba state process did not finish within 600 seconds.")
    if proc.exitcode != 0:
        raise RuntimeError(f"Ref mamba state process exited with code {proc.exitcode}.")


def _run_ref_mamba_state_worker():
    """Generate reference mamba states and save them to disk.

    Runs an engine with the patched ``execute_model`` and ``_sample`` (and
    without the prefix-caching / speculative options used by the test),
    generates from the first 500 article tokens, and writes the captured
    snapshots, moved to CPU, to ``mamba_kv_cache_dict_ref.pth``.

    Any exception is printed with its traceback and re-raised.
    """
    try:
        os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
        num_generated_tokens = 8000
        num_prompt_tokens = 500
        sampling_params = SamplingParams(
            temperature=0.0, max_tokens=num_generated_tokens
        )
        prompt_dataset = datasets.load_dataset("heheda/a_long_article")
        full_prompt = prompt_dataset["train"][0]["text"]
        # Patch the class attributes directly; this function is used as the
        # target of a separate process (see run_ref_mamba_state_in_subprocess).
        fake_execute_model_fn = get_fake_execute_model_fn(GPUModelRunner.execute_model)
        GPUModelRunner.execute_model = fake_execute_model_fn
        fake_sample_fn = get_fake_sample_fn()
        GPUModelRunner._sample = fake_sample_fn
        engine = LLM(
            model=MODEL,
            block_size=BLOCK_SIZE,
            hf_overrides={"num_hidden_layers": NUM_HIDDEN_LAYERS},
            seed=42,
        )
        global prompt_token_ids
        prompt_token_ids = engine.get_tokenizer().encode(full_prompt)
        print(f"Token IDs length: {len(prompt_token_ids)}")

        # The generated outputs are not needed; the run populates
        # mamba_kv_cache_dict through the patched execute_model.
        engine.generate(
            [TokensPrompt(prompt_token_ids=prompt_token_ids[:num_prompt_tokens])],
            sampling_params,
        )
        cpu_state_ref = {
            key: tuple(tensor.detach().cpu() for tensor in tensors)
            for key, tensors in mamba_kv_cache_dict.items()
        }
        torch.save(cpu_state_ref, "mamba_kv_cache_dict_ref.pth")
        mamba_kv_cache_dict.clear()
        del engine
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()
    except Exception:
        traceback.print_exc()
        raise


def check_mamba_state_equal(
    mamba_state_ref: dict, mamba_state_new: dict, keys_to_check: list[int]
):
    """Compare selected entries of two mamba state dictionaries.

    For every key in ``keys_to_check`` the tensors stored under that key are
    compared pairwise with an absolute and relative tolerance of 1e-2. Each
    new tensor is moved to the reference tensor's device and truncated along
    its first dimension to the reference tensor's length before comparing.

    A pair in which fewer than 1% of the elements differ only produces a
    warning on stdout.

    Returns:
        True if no pair differs in 1% or more of its elements.

    Raises:
        AssertionError: If a key is missing from either dictionary.
        ValueError: If a pair differs in 1% or more of its elements.
    """
    tolerance = {"atol": 1e-2, "rtol": 1e-2}
    for key in keys_to_check:
        assert key in mamba_state_new
        assert key in mamba_state_ref
        pairs = zip(mamba_state_ref[key], mamba_state_new[key])
        for i, (ref, new) in enumerate(pairs):
            if new.device != ref.device:
                new = new.to(ref.device)
            # Only the leading rows covered by the reference are compared.
            new = new[: ref.shape[0]]
            mismatch = ~torch.isclose(ref, new, **tolerance)
            num_diff = int(mismatch.sum().item())
            if num_diff == 0:
                continue
            if num_diff * 100 >= ref.numel():
                raise ValueError(
                    f"Mamba state is not equal for key: {key} at index {i}"
                )
            print(
                f"[WARNING] found {num_diff * 100 / ref.numel()}% of the elements are different"  # noqa: E501
            )
    return True


@dataclass
class TestConfig:
    """Parameters of one scripted prefix-cache test case."""

    # The class name starts with "Test", so pytest would try to collect it
    # and warn about the generated __init__; opt out of collection. This is
    # a plain class attribute (no annotation), not a dataclass field.
    __test__ = False

    # Number of leading article tokens used as the prompt.
    num_prompt_tokens: int
    # Passed to SamplingParams as max_tokens.
    num_generated_tokens: int
    # Value assigned to the module-level num_accepted_tokens for this case.
    num_accepted_tokens: int
    # Expected per-step observations, consumed in order.
    step_actions: list[StepAction]


def apply_patch(monkeypatch: pytest.MonkeyPatch):
    """Install every fake used by the test through ``monkeypatch``.

    Sets ``VLLM_ENABLE_V1_MULTIPROCESSING`` to "0" and replaces the sampler,
    the draft-token proposer, ``execute_model``, ``InprocClient.get_output``,
    ``KVCacheManager.allocate_slots`` and the three ``mamba_utils`` functions
    with their checking wrappers.
    """
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    # Model-runner hooks. The wrapper factory receives the original
    # execute_model because its argument is evaluated before setattr runs.
    monkeypatch.setattr(GPUModelRunner, "_sample", get_fake_sample_fn())
    monkeypatch.setattr(
        GPUModelRunner,
        "propose_draft_token_ids",
        get_fake_propose_draft_token_ids_fn(),
    )
    monkeypatch.setattr(
        GPUModelRunner,
        "execute_model",
        get_fake_execute_model_fn(GPUModelRunner.execute_model),
    )

    # Engine-client and KV-cache-manager hooks.
    monkeypatch.setattr(
        InprocClient, "get_output", get_fake_step_action_fn(InprocClient.get_output)
    )
    monkeypatch.setattr(
        KVCacheManager,
        "allocate_slots",
        get_fake_allocate_slots_fn(KVCacheManager.allocate_slots),
    )

    # Mamba state pre/post-processing and copy hooks.
    preprocess_fake, postprocess_fake, copy_fake = get_fake_process_mamba_fn(
        mamba_utils.preprocess_mamba,
        mamba_utils.postprocess_mamba,
        mamba_utils.do_mamba_copy_block,
    )
    mamba_fakes = (
        ("preprocess_mamba", preprocess_fake),
        ("postprocess_mamba", postprocess_fake),
        ("do_mamba_copy_block", copy_fake),
    )
    for attr_name, fake_fn in mamba_fakes:
        monkeypatch.setattr(mamba_utils, attr_name, fake_fn)


517
@create_new_process_for_each_test()
def test_mamba_prefix_cache(monkeypatch: pytest.MonkeyPatch):
    """Check scheduling, block layout and mamba state copies step by step.

    A reference run first writes mamba state snapshots to
    ``mamba_kv_cache_dict_ref.pth``. Then an engine with prefix caching,
    ``mamba_cache_mode="align"`` and speculative decoding runs a series of
    scripted cases. For each case the patched hooks assert the per-step
    expectations, and the snapshots captured at the block boundaries where a
    postprocess copy is expected are compared with the reference.
    """
    global prompt_token_ids, num_accepted_tokens
    global cur_step_action_idx, step_actions

    run_ref_mamba_state_in_subprocess()
    apply_patch(monkeypatch)
    prompt_dataset = datasets.load_dataset("heheda/a_long_article")
    full_prompt = prompt_dataset["train"][0]["text"]

    # Marker used in StepAction for "no copy expected in this phase".
    no_copy = (-1, -1)
    tests = {
        "accept_1": TestConfig(
            num_prompt_tokens=554,
            num_generated_tokens=20,
            num_accepted_tokens=1,
            step_actions=[
                StepAction(0, 554, [1, 1, 1, 1], no_copy, no_copy),
                StepAction(554, 4, [], no_copy, no_copy),
                StepAction(555, 4, [1, 1, 1, 1, 1], no_copy, no_copy),
                StepAction(556, 4, [], no_copy, no_copy),
                StepAction(557, 4, [], (0, 1), no_copy),
                StepAction(558, 4, [], no_copy, no_copy),
                StepAction(559, 4, [], no_copy, (1, 0)),
                StepAction(560, 4, [], no_copy, no_copy),
                StepAction(561, 4, [0, 1, 1, 1, 1], no_copy, no_copy),
            ],
        ),
        # test case 2.1: no hit, accept 2 tokens
        "accept_2_1": TestConfig(
            num_prompt_tokens=554,
            num_generated_tokens=20,
            num_accepted_tokens=2,
            step_actions=[
                StepAction(0, 554, [1, 1, 1, 1], no_copy, no_copy),
                StepAction(554, 4, [], no_copy, no_copy),
                StepAction(556, 4, [1, 1, 1, 1, 1], no_copy, no_copy),
                StepAction(558, 4, [], (1, 1), (2, 0)),
                StepAction(560, 4, [], no_copy, no_copy),
                StepAction(562, 4, [0, 1, 1, 1, 1], no_copy, no_copy),
            ],
        ),
        # test case 2.2: no hit, accept 2 tokens
        "accept_2_2": TestConfig(
            num_prompt_tokens=555,
            num_generated_tokens=20,
            num_accepted_tokens=2,
            step_actions=[
                StepAction(0, 555, [1, 1, 1, 1], no_copy, no_copy),
                StepAction(555, 4, [], no_copy, no_copy),
                StepAction(557, 4, [1, 1, 1, 1, 1], (1, 1), no_copy),
                StepAction(559, 4, [], no_copy, (1, 0)),
                StepAction(561, 4, [], no_copy, no_copy),
                StepAction(563, 4, [0, 1, 1, 1, 1], no_copy, no_copy),
            ],
        ),
        "accept_3_1": TestConfig(
            num_prompt_tokens=553,
            num_generated_tokens=20,
            num_accepted_tokens=3,
            step_actions=[
                StepAction(0, 553, [1, 1, 1, 1], no_copy, no_copy),
                StepAction(553, 4, [], no_copy, no_copy),
                StepAction(556, 4, [1, 1, 1, 1, 1], no_copy, no_copy),
                StepAction(559, 4, [], (2, 1), (1, 0)),
                StepAction(562, 4, [], no_copy, no_copy),
                StepAction(565, 4, [0, 1, 1, 1, 1], no_copy, no_copy),
            ],
        ),
        "accept_3_2": TestConfig(
            num_prompt_tokens=554,
            num_generated_tokens=20,
            num_accepted_tokens=3,
            step_actions=[
                StepAction(0, 554, [1, 1, 1, 1], no_copy, no_copy),
                StepAction(554, 4, [], no_copy, no_copy),
                StepAction(557, 4, [1, 1, 1, 1, 1], (2, 1), (3, 0)),
                StepAction(560, 4, [], no_copy, no_copy),
                StepAction(563, 4, [0, 1, 1, 1, 1], no_copy, no_copy),
            ],
        ),
        "accept_3_3": TestConfig(
            num_prompt_tokens=555,
            num_generated_tokens=20,
            num_accepted_tokens=3,
            step_actions=[
                StepAction(0, 555, [1, 1, 1, 1], no_copy, no_copy),
                StepAction(555, 4, [], no_copy, no_copy),
                StepAction(558, 4, [1, 1, 1, 1, 1], (2, 1), (2, 0)),
                StepAction(561, 4, [], no_copy, no_copy),
                StepAction(564, 4, [0, 1, 1, 1, 1], no_copy, no_copy),
            ],
        ),
        "accept_4_1": TestConfig(
            num_prompt_tokens=553,
            num_generated_tokens=20,
            num_accepted_tokens=4,
            step_actions=[
                StepAction(0, 553, [1, 1, 1, 1], no_copy, no_copy),
                StepAction(553, 4, [], no_copy, no_copy),
                StepAction(557, 4, [1, 1, 1, 1, 1], (3, 1), (3, 0)),
                StepAction(561, 4, [], no_copy, no_copy),
                StepAction(565, 4, [0, 1, 1, 1, 1], no_copy, no_copy),
            ],
        ),
        "accept_4_2": TestConfig(
            num_prompt_tokens=554,
            num_generated_tokens=25,
            num_accepted_tokens=4,
            step_actions=[
                StepAction(0, 554, [1, 1, 1, 1], no_copy, no_copy),
                StepAction(554, 4, [], no_copy, no_copy),
                StepAction(558, 4, [1, 1, 1, 1, 1], (3, 1), (2, 0)),
                StepAction(562, 4, [], no_copy, no_copy),
                StepAction(566, 4, [0, 1, 1, 1, 1], no_copy, no_copy),
            ],
        ),
        "accept_4_3": TestConfig(
            num_prompt_tokens=555,
            num_generated_tokens=25,
            num_accepted_tokens=4,
            step_actions=[
                StepAction(0, 555, [1, 1, 1, 1], no_copy, no_copy),
                StepAction(555, 4, [], no_copy, no_copy),
                StepAction(559, 4, [1, 1, 1, 1, 1], (3, 1), (1, 0)),
                StepAction(563, 4, [], no_copy, no_copy),
                StepAction(567, 4, [0, 1, 1, 1, 1], no_copy, no_copy),
            ],
        ),
        "accept_4_4": TestConfig(
            num_prompt_tokens=556,
            num_generated_tokens=25,
            num_accepted_tokens=4,
            step_actions=[
                StepAction(0, 556, [1, 1, 1, 1], no_copy, no_copy),
                StepAction(556, 4, [], no_copy, (3, 0)),
                StepAction(560, 4, [1, 1, 1, 1, 1], (0, 1), no_copy),
                StepAction(564, 4, [0, 1, 1, 1, 1], no_copy, no_copy),
            ],
        ),
        "prompt_block_size": TestConfig(
            num_prompt_tokens=560,
            num_generated_tokens=10,
            num_accepted_tokens=4,
            step_actions=[
                StepAction(0, 560, [1, 1, 1, 1], no_copy, no_copy),
                StepAction(560, 4, [1, 1, 1, 1, 1], (0, 1), no_copy),
            ],
        ),
        "prompt_2_block_size": TestConfig(
            num_prompt_tokens=560 * 2,
            num_generated_tokens=10,
            num_accepted_tokens=4,
            step_actions=[
                StepAction(0, 560, [1, 1, 1, 1], no_copy, no_copy),
                StepAction(560, 560, [1, 1, 1, 1, 1], (0, 1), no_copy),
                StepAction(560 * 2, 4, [0, 1, 1, 1, 1, 1], (1, 2), no_copy),
            ],
        ),
        "prompt_2_block_size_10": TestConfig(
            num_prompt_tokens=560 * 2 + 10,
            num_generated_tokens=10,
            num_accepted_tokens=4,
            step_actions=[
                StepAction(0, 560, [1, 1, 1, 1], no_copy, no_copy),
                StepAction(560, 570, [1, 0, 1, 1, 1, 1], (0, 2), no_copy),
                StepAction(560 * 2 + 10, 4, [0, 0, 1, 1, 1, 1], no_copy, no_copy),
            ],
        ),
        "prompt_3_block_size": TestConfig(
            num_prompt_tokens=560 * 3,
            num_generated_tokens=10,
            num_accepted_tokens=4,
            step_actions=[
                StepAction(0, 560 * 2, [0, 1, 1, 1, 1], no_copy, no_copy),
                StepAction(560 * 2, 560, [0, 1, 1, 1, 1, 1], (1, 2), no_copy),
                StepAction(560 * 3, 4, [0, 0, 1, 1, 1, 1, 1], (2, 3), no_copy),
            ],
        ),
        "prompt_3_block_size_10": TestConfig(
            num_prompt_tokens=560 * 3 + 10,
            num_generated_tokens=10,
            num_accepted_tokens=4,
            step_actions=[
                StepAction(0, 560 * 2, [0, 1, 1, 1, 1], no_copy, no_copy),
                StepAction(560 * 2, 570, [0, 1, 0, 1, 1, 1, 1], (1, 3), no_copy),
                StepAction(560 * 3 + 10, 4, [0, 0, 0, 1, 1, 1, 1], no_copy, no_copy),
            ],
        ),
        "prompt_10_block_size": TestConfig(
            num_prompt_tokens=560 * 10,
            num_generated_tokens=10,
            num_accepted_tokens=4,
            step_actions=[
                StepAction(0, 560 * 5, [0, 0, 0, 0, 1, 1, 1, 1], no_copy, no_copy),
                StepAction(
                    560 * 5,
                    560 * 4,
                    [0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1],
                    (4, 8),
                    no_copy,
                ),
                StepAction(
                    560 * 9,
                    560,
                    [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
                    (8, 9),
                    no_copy,
                ),
                StepAction(
                    560 * 10,
                    4,
                    [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
                    (9, 10),
                    no_copy,
                ),
            ],
        ),
        "prompt_10_block_size_10": TestConfig(
            num_prompt_tokens=560 * 10 + 10,
            num_generated_tokens=10,
            num_accepted_tokens=4,
            step_actions=[
                StepAction(0, 560 * 5, [0, 0, 0, 0, 1, 1, 1, 1], no_copy, no_copy),
                StepAction(
                    560 * 5,
                    560 * 4,
                    [0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1],
                    (4, 8),
                    no_copy,
                ),
                StepAction(
                    560 * 9,
                    560 + 10,
                    [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1],
                    (8, 10),
                    no_copy,
                ),
            ],
        ),
    }

    engine = LLM(
        model=MODEL,
        enable_prefix_caching=True,
        block_size=BLOCK_SIZE,
        mamba_cache_mode="align",
        speculative_config={
            "method": "qwen3_next_mtp",
            "num_speculative_tokens": num_speculative_tokens,
        },
        max_num_batched_tokens=3072,
        hf_overrides={"num_hidden_layers": NUM_HIDDEN_LAYERS},
        seed=42,
    )
    prompt_token_ids = engine.get_tokenizer().encode(full_prompt)
    print(f"Token IDs length: {len(prompt_token_ids)}")

    # The reference file is written once by the reference run above and is
    # only read here, so load it a single time for all cases.
    mamba_state_ref = torch.load("mamba_kv_cache_dict_ref.pth")

    for test_case_name, test_config in tests.items():
        print(f"Running test case: {test_case_name}")
        num_generated_tokens = test_config.num_generated_tokens
        num_prompt_tokens = test_config.num_prompt_tokens
        num_accepted_tokens = test_config.num_accepted_tokens
        sampling_params = SamplingParams(
            temperature=0.0, max_tokens=num_generated_tokens
        )
        cur_step_action_idx = 0

        # An empty kv_cache_block_ids means "same layout as the previous
        # step". Resolve these in order so consecutive empty entries chain.
        for step_action_prev, step_action_next in zip(
            test_config.step_actions[:-1], test_config.step_actions[1:]
        ):
            if (
                step_action_next.kv_cache_block_ids is not None
                and len(step_action_next.kv_cache_block_ids) == 0
            ):
                prev_block_ids = step_action_prev.kv_cache_block_ids
                if prev_block_ids is not None:
                    step_action_next.kv_cache_block_ids = prev_block_ids.copy()

        step_actions = test_config.step_actions
        engine.generate(
            [TokensPrompt(prompt_token_ids=prompt_token_ids[:num_prompt_tokens])],
            sampling_params,
        )
        # Reset the prefix cache between cases.
        assert engine.llm_engine.engine_core.engine_core.scheduler.reset_prefix_cache()
        print(f"End test case: {test_case_name}")

        # Compare the snapshot at each block boundary where a postprocess
        # copy was expected; the key is derived from the copy's destination
        # index.
        keys_to_check = [
            (action.postprocess_copy_idx[1] + 1) * BLOCK_SIZE
            for action in test_config.step_actions
            if action.postprocess_copy_idx and action.postprocess_copy_idx[0] != -1
        ]
        check_mamba_state_equal(mamba_state_ref, mamba_kv_cache_dict, keys_to_check)
        mamba_kv_cache_dict.clear()

    del engine
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()