test_hybrid.py 24.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Callable
5

Mor Zusman's avatar
Mor Zusman committed
6
7
import pytest

8
from tests.models.registry import HF_EXAMPLE_MODELS
9
from tests.utils import multi_gpu_test
10
from vllm.engine.arg_utils import EngineArgs
11
from vllm.sampling_params import SamplingParams
12

13
from ...utils import check_logprobs_close, check_outputs_equal
14

# Mark all tests as hybrid
pytestmark = pytest.mark.hybrid_model

# NOTE: The first entry of SSM_MODELS and of HYBRID_MODELS is the primary
# model of its list: most tests below are parametrized with index 0 only,
# while test_models runs every entry. The APC tests additionally use
# HYBRID_MODELS[3], and the full-CUDA-graph / fp32-state tests use their
# own dedicated lists below.

# Repetition factor used to turn short prompts into long ones for the
# automatic-prefix-caching (APC) tests.
APC_MULTIPLY_BY = 300

SSM_MODELS = [
    "state-spaces/mamba-130m-hf",
    "tiiuae/falcon-mamba-tiny-dev",
    # mamba2-codestral in transformers is broken pending:
    # https://github.com/huggingface/transformers/pull/40861
    # "yujiepan/mamba2-codestral-v0.1-tiny-random",
]

HYBRID_MODELS = [
    "ai21labs/Jamba-tiny-dev",
    "pfnet/plamo-2-1b",
    "Zyphra/Zamba2-1.2B-instruct",
    "hmellor/tiny-random-BambaForCausalLM",
    "ibm-granite/granite-4.0-tiny-preview",
    "tiiuae/Falcon-H1-0.5B-Base",
    "LiquidAI/LFM2-1.2B",
    "tiny-random/qwen3-next-moe",
]

# Models exercised by test_full_cuda_graph.
FULL_CUDA_GRAPH_MODELS = [
    "ai21labs/Jamba-tiny-dev",
    "pfnet/plamo-2-1b",
    "Zyphra/Zamba2-1.2B-instruct",
]

# Models exercised by test_fp32_cache_state.
FP32_STATE_MODELS = [
    "state-spaces/mamba-130m-hf",
    "Zyphra/Zamba2-1.2B-instruct",
]

# Cap on concurrent sequences passed to most vLLM runners here (avoid OOM).
MAX_NUM_SEQS = 4

@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Check that vLLM greedy logprobs stay close to the HF reference.

    Runs for every SSM and hybrid model. The example-model registry checks
    are invoked with ``on_fail="skip"``; if any of them raises ``ValueError``
    (presumably a model missing from the registry) the checks are simply
    bypassed and the comparison still runs.
    """
    try:
        info = HF_EXAMPLE_MODELS.find_hf_info(model)
        info.check_available_online(on_fail="skip")
        info.check_transformers_version(on_fail="skip")
    except ValueError:
        # Deliberately best-effort: proceed without the registry checks.
        pass

    with hf_runner(model) as reference:
        expected = reference.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs
        )

    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as engine:
        actual = engine.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs
        )

    check_logprobs_close(
        outputs_0_lst=expected,
        outputs_1_lst=actual,
        name_0="hf",
        name_1="vllm",
    )


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_batching(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Batched generation must match generating each prompt on its own.

    A single vLLM engine first processes the prompts one at a time and then
    all together in one call; the greedy logprobs of both runs are compared.
    """
    try:
        info = HF_EXAMPLE_MODELS.find_hf_info(model)
        info.check_available_online(on_fail="skip")
        info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    sequential = []
    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as engine:
        for prompt in example_prompts:
            # Tuple-unpacking also asserts exactly one result per call.
            (result,) = engine.generate_greedy_logprobs(
                [prompt], max_tokens, num_logprobs
            )
            sequential.append(result)

        batched = engine.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs
        )

    check_logprobs_close(
        outputs_0_lst=sequential,
        outputs_1_lst=batched,
        name_0="for_loop_vllm",
        name_1="batched_vllm",
    )


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [10])
def test_chunked_prefill_with_parallel_sampling(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
) -> None:
    """
    Tests chunked prefill in conjunction with n > 1.

    A small per-step token budget makes prefill chunks share forward passes
    with decoding tokens; the test only checks that generation does not fail.

    It may fail if the cache is not allocated correctly for n > 1 decoding
    steps inside a chunked-prefill forward pass (where prefill and decode
    run together).
    """
    params = SamplingParams(n=3, temperature=1, seed=0, max_tokens=max_tokens)
    # Small budget: forces prefill chunks to be mixed with decoding.
    token_budget = MAX_NUM_SEQS * 3
    with vllm_runner(
        model,
        enable_chunked_prefill=True,
        max_num_batched_tokens=token_budget,
        max_num_seqs=MAX_NUM_SEQS,
    ) as engine:
        engine.generate(example_prompts, params)


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [20])
def test_mamba_cache_cg_padding(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
) -> None:
    """
    Verify that the mamba cache is padded to the CUDA-graph captured batch
    size. If it is not, a torch RuntimeError is raised because tensor
    dimensions are incompatible.

    The prompt list is grown until its length differs from
    ``pad_for_cudagraph(len)``, so the batch must be padded before it can
    use a captured graph.
    """
    vllm_config = EngineArgs(model=model, trust_remote_code=True).create_engine_config()
    # Work on a copy: appending to the shared ``example_prompts`` fixture
    # would leak the extra prompts to anything else holding that list.
    prompts = list(example_prompts)
    while len(prompts) == vllm_config.pad_for_cudagraph(len(prompts)):
        prompts.append(prompts[0])

    try:
        with vllm_runner(model) as vllm_model:
            vllm_model.generate_greedy(prompts, max_tokens)
    except RuntimeError:
        pytest.fail(
            "Couldn't run batch size which is not equal to a Cuda Graph "
            "captured batch size. "
            "Could be related to mamba cache not padded correctly"
        )


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
    vllm_runner,
    example_prompts,
    model: str,
) -> None:
    """
    Verify that hybrid inner-state management does not collapse when the
    number of incoming requests and finished_requests_ids exceeds the
    maximum mamba block capacity.

    This can happen because the hybrid state mechanism may clean up state
    for new incoming requests within a single step. Here 100 copies of one
    prompt are submitted to an engine limited to ``MAX_NUM_SEQS`` sequences;
    a ``ValueError`` is reported as a state-cleanup failure.
    """
    try:
        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
            vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
    except ValueError:
        # The two literals are concatenated implicitly; keep the separator.
        pytest.fail(
            "Hybrid inner state wasn't cleaned up properly between "
            "steps; finished requests registered unnecessarily"
        )


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
def test_state_cleanup(
    vllm_runner,
    example_prompts,
    model: str,
) -> None:
    """
    Verify that the hybrid state is cleaned up between steps.

    Ten consecutive batches of 100 identical single-token requests run on
    one engine. A ``ValueError`` from any of them is reported as a failure
    to release state.
    """
    try:
        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as engine:
            for _step in range(10):
                requests = [example_prompts[0]] * 100
                engine.generate_greedy(requests, 1)
    except ValueError:
        pytest.fail(
            "Hybrid inner state wasn't cleaned up between states, "
            "could be related to finished_requests_ids"
        )


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_distributed_correctness(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Greedy logprobs with tensor parallelism 2 must match TP=1 output.

    The TP=1 engine runs and is exited before the TP=2 engine is created.
    """
    outputs_by_tp = {}
    for tp_size in (1, 2):
        with vllm_runner(
            model, tensor_parallel_size=tp_size, max_num_seqs=MAX_NUM_SEQS
        ) as engine:
            outputs_by_tp[tp_size] = engine.generate_greedy_logprobs(
                example_prompts, max_tokens, num_logprobs
            )

    check_logprobs_close(
        outputs_0_lst=outputs_by_tp[1],
        outputs_1_lst=outputs_by_tp[2],
        name_0="vllm_tp_1",
        name_1="vllm_tp_2",
    )


@pytest.mark.parametrize("model", FULL_CUDA_GRAPH_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_full_cuda_graph(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Compare HF and vLLM greedy logprobs for FULL_CUDA_GRAPH_MODELS.

    NOTE(review): no CUDA-graph option is passed to ``vllm_runner`` here, so
    this relies on the engine's default compilation settings to exercise
    full CUDA graphs. Confirm that the default still does so.
    """
    try:
        info = HF_EXAMPLE_MODELS.find_hf_info(model)
        info.check_available_online(on_fail="skip")
        info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    with hf_runner(model) as reference:
        expected = reference.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs
        )

    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as engine:
        actual = engine.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs
        )

    check_logprobs_close(
        outputs_0_lst=expected,
        outputs_1_lst=actual,
        name_0="hf",
        name_1="vllm",
    )


@pytest.mark.parametrize("model", FP32_STATE_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize(
    "cache_dtype_param", ["mamba_ssm_cache_dtype", "mamba_cache_dtype"]
)
def test_fp32_cache_state(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    num_logprobs: int,
    cache_dtype_param: str,
) -> None:
    """Compare HF logprobs against vLLM run with a float32 mamba cache.

    ``cache_dtype_param`` picks which engine option (``mamba_ssm_cache_dtype``
    or ``mamba_cache_dtype``) is set to ``"float32"``.
    """
    try:
        info = HF_EXAMPLE_MODELS.find_hf_info(model)
        info.check_available_online(on_fail="skip")
        info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    with hf_runner(model) as reference:
        expected = reference.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs
        )

    engine_kwargs = {"max_num_seqs": MAX_NUM_SEQS, cache_dtype_param: "float32"}
    with vllm_runner(model, **engine_kwargs) as engine:
        actual = engine.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs
        )

    check_logprobs_close(
        outputs_0_lst=expected,
        outputs_1_lst=actual,
        name_0="hf",
        name_1="vllm",
    )


# Helper functions for the APC tests
351
352
353
354
355
def _get_vllm_runner_params(
    model: str,
    max_model_len: int,
    tensor_parallel_size: int = 1,
):
356
    return {
357
        "model_name": model,
358
        "enable_chunked_prefill": True,
359
360
361
362
        "enable_prefix_caching": False,
        "max_model_len": max_model_len,
        "tensor_parallel_size": tensor_parallel_size,
        "gpu_memory_utilization": 0.4,
363
364
365
    }


366
367
368
369
370
371
372
373
374
def _get_vLLM_output(
    vllm_runner,
    kwargs,
    prompts,
    max_tokens,
    num_logprobs,
    num_repetitions=1,
    vllm_model=None,
):
375
376
377
378
379
380
381
382
    outs = []
    if vllm_model is None:
        vllm_model = vllm_runner(**kwargs)
    for _ in range(num_repetitions):
        if num_logprobs < 0:
            vllm_output = vllm_model.generate_greedy(prompts, max_tokens)
        else:
            vllm_output = vllm_model.generate_greedy_logprobs(
383
384
                prompts, max_tokens, num_logprobs
            )
385
386
387
388
389
        outs.append(vllm_output)

    return outs, vllm_model


@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_single_prompt(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    n_repetitions: int,
    num_logprobs: int,
    tensor_parallel_size: int,
) -> None:
    """Automatic prefix caching must not change outputs for one long prompt.

    The first example prompt repeated ``APC_MULTIPLY_BY`` times is generated
    once without prefix caching. It is then generated ``n_repetitions`` times
    on one engine with prefix caching enabled, and every cached repetition is
    compared with the uncached baseline.
    """
    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
        model_info.check_available_online(on_fail="skip")
        model_info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    compare_operator: Callable = (
        check_logprobs_close if num_logprobs > 0 else check_outputs_equal  # type: ignore
    )

    # Sample prompts.
    generated_prompts = [APC_MULTIPLY_BY * example_prompts[0]]

    # NOTE: len(prompt) counts characters, not tokens; presumably this
    # over-estimates the token length - verify if the prompts change.
    max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
    vllm_runner_kwargs = _get_vllm_runner_params(
        model, max_model_len, tensor_parallel_size=tensor_parallel_size
    )
    vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32"

    # Each engine lives in its own `with` block, so its context manager is
    # exited before the next engine is created instead of leaving it open.
    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        vllm_outputs_no_cache, _ = _get_vLLM_output(
            vllm_runner,
            vllm_runner_kwargs,
            generated_prompts,
            max_tokens,
            num_logprobs,
            vllm_model=vllm_model,
        )

    vllm_runner_kwargs["enable_prefix_caching"] = True
    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        vllm_outputs_cache_rep, _ = _get_vLLM_output(
            vllm_runner,
            vllm_runner_kwargs,
            generated_prompts,
            max_tokens,
            num_logprobs,
            n_repetitions,
            vllm_model=vllm_model,
        )

    for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
        # In the first repetition, the caches are filled
        # In the second repetition, these caches are reused
        compare_operator(
            outputs_0_lst=vllm_outputs_no_cache[0],
            outputs_1_lst=vllm_outputs_cache_itn,
            name_0="vllm_no_cache",
            name_1=f"vllm_cache_it_{r_idx + 1}",
        )


@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_single_prompt_block_align_alignment(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    n_repetitions: int,
    num_logprobs: int,
    tensor_parallel_size: int,
) -> None:
    """Prefix-cached outputs must match the uncached baseline for several
    ``max_num_batched_tokens`` values near multiples of the mamba block size.

    A single long repeated prompt is generated once without prefix caching.
    Next, the engine's ``mamba_block_size`` is read, with 512 used when it is
    ``None``. For each offset, a fresh prefix-caching engine with
    ``max_num_batched_tokens = 10 * mamba_block_size - offset`` runs the prompt
    ``n_repetitions`` times, and each repetition is compared with the baseline.
    """
    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
        model_info.check_available_online(on_fail="skip")
        model_info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    compare_operator: Callable = (
        check_logprobs_close if num_logprobs > 0 else check_outputs_equal  # type: ignore
    )

    # Sample prompts. This custom prompt is used, as it causes the most issues
    generated_prompts = ["The president of the United States is " * APC_MULTIPLY_BY]

    # NOTE: len(prompt) counts characters, not tokens.
    max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
    vllm_runner_kwargs = _get_vllm_runner_params(
        model, max_model_len, tensor_parallel_size=tensor_parallel_size
    )
    vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32"

    # Every engine below is used inside its own `with` block, so it is exited
    # before the next one is created rather than being left open.
    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        vllm_outputs_no_cache, _ = _get_vLLM_output(
            vllm_runner,
            vllm_runner_kwargs,
            generated_prompts,
            max_tokens,
            num_logprobs,
            vllm_model=vllm_model,
        )

    vllm_runner_kwargs["enable_prefix_caching"] = True
    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        # Retrieve the default mamba state block size
        mamba_block_size = vllm_model.llm.llm_engine.cache_config.mamba_block_size

    # In case the hybrid model does not have the
    # "mamba_block_size" assume a fixed constant
    if mamba_block_size is None:
        mamba_block_size = 512

    mamba_block_size_multiplier = 10
    # Vary the per-step token budget around a multiple of the block size,
    # presumably so prefill chunk boundaries fall at different positions
    # relative to mamba block boundaries.
    for offset in [-3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3]:
        vllm_runner_kwargs["max_num_batched_tokens"] = (
            mamba_block_size_multiplier * mamba_block_size - offset
        )
        with vllm_runner(**vllm_runner_kwargs) as vllm_model:
            vllm_outputs_cache_rep, _ = _get_vLLM_output(
                vllm_runner,
                vllm_runner_kwargs,
                generated_prompts,
                max_tokens,
                num_logprobs,
                n_repetitions,
                vllm_model=vllm_model,
            )

        # Check alignment of the output logits when using APC
        for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
            # In the first repetition, the caches are filled
            # In the second repetition, these caches are reused
            compare_operator(
                outputs_0_lst=vllm_outputs_no_cache[0],
                outputs_1_lst=vllm_outputs_cache_itn,
                name_0="vllm_no_cache",
                name_1=f"vllm_cache_it_{r_idx + 1}",
            )


@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_multiple_prompts_all_cached_outputs(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    n_repetitions: int,
    num_logprobs: int,
    tensor_parallel_size: int,
) -> None:
    """Automatic prefix caching must not change outputs for a prompt batch.

    Every example prompt is repeated ``APC_MULTIPLY_BY`` times. The batch is
    generated once without prefix caching and then ``n_repetitions`` times on
    one prefix-caching engine, and each cached repetition is compared with
    the baseline.
    """
    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
        model_info.check_available_online(on_fail="skip")
        model_info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    compare_operator: Callable = (
        check_logprobs_close if num_logprobs > 0 else check_outputs_equal  # type: ignore
    )

    # Sample prompts.
    generated_prompts = [APC_MULTIPLY_BY * prompt for prompt in example_prompts]

    # NOTE: len(prompt) counts characters, not tokens.
    max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
    vllm_runner_kwargs = _get_vllm_runner_params(
        model, max_model_len, tensor_parallel_size=tensor_parallel_size
    )
    vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32"

    # Each engine lives in its own `with` block, so it is exited before the
    # next engine is created instead of being left open.
    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        vllm_outputs_no_cache, _ = _get_vLLM_output(
            vllm_runner,
            vllm_runner_kwargs,
            generated_prompts,
            max_tokens,
            num_logprobs,
            vllm_model=vllm_model,
        )

    vllm_runner_kwargs["enable_prefix_caching"] = True
    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        vllm_outputs_cache_rep, _ = _get_vLLM_output(
            vllm_runner,
            vllm_runner_kwargs,
            generated_prompts,
            max_tokens,
            num_logprobs,
            n_repetitions,
            vllm_model=vllm_model,
        )

    for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
        # In the first repetition, the caches are filled
        # In the second repetition, these caches are reused
        compare_operator(
            outputs_0_lst=vllm_outputs_no_cache[0],
            outputs_1_lst=vllm_outputs_cache_itn,
            name_0="vllm_no_cache",
            name_1=f"vllm_cache_it_{r_idx + 1}",
        )


@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_multiple_prompts_block_align_alignment(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    n_repetitions: int,
    num_logprobs: int,
    tensor_parallel_size: int,
) -> None:
    """Prefix-cached outputs for a batch of shifted prompts must match the
    uncached baseline for several ``max_num_batched_tokens`` values near
    multiples of the mamba block size.

    The prompts are suffixes of one sentence, each repeated
    ``APC_MULTIPLY_BY`` times. After an uncached baseline run, each offset
    gets a fresh prefix-caching engine with
    ``max_num_batched_tokens = 10 * mamba_block_size - offset``. That engine
    runs the batch ``n_repetitions`` times, and each run is compared with the
    baseline.
    """
    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
        model_info.check_available_online(on_fail="skip")
        model_info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    compare_operator: Callable = (
        check_logprobs_close if num_logprobs > 0 else check_outputs_equal  # type: ignore
    )

    # Sample prompts. This custom prompt is used, as it causes the most issues
    prompt_text = "The president of the United States is "
    prompt_offsets = [0, 3, 7, 13, 17, 22, 25, 31]
    generated_prompts = [
        prompt_text[offset:] * APC_MULTIPLY_BY for offset in prompt_offsets
    ]

    # NOTE: len(prompt) counts characters, not tokens.
    max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
    vllm_runner_kwargs = _get_vllm_runner_params(
        model, max_model_len, tensor_parallel_size=tensor_parallel_size
    )
    vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32"

    # Every engine below is used inside its own `with` block, so it is exited
    # before the next one is created rather than being left open.
    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        vllm_outputs_no_cache, _ = _get_vLLM_output(
            vllm_runner,
            vllm_runner_kwargs,
            generated_prompts,
            max_tokens,
            num_logprobs,
            vllm_model=vllm_model,
        )

    vllm_runner_kwargs["enable_prefix_caching"] = True
    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        # Retrieve the default mamba state block size
        mamba_block_size = vllm_model.llm.llm_engine.cache_config.mamba_block_size

    # In case the hybrid model does not have the
    # "mamba_block_size" assume a fixed constant
    if mamba_block_size is None:
        mamba_block_size = 512

    mamba_block_size_multiplier = 10
    # Vary the per-step token budget around a multiple of the block size,
    # presumably so prefill chunk boundaries fall at different positions
    # relative to mamba block boundaries.
    for offset in [-3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3]:
        vllm_runner_kwargs["max_num_batched_tokens"] = (
            mamba_block_size_multiplier * mamba_block_size - offset
        )
        with vllm_runner(**vllm_runner_kwargs) as vllm_model:
            vllm_outputs_cache_rep, _ = _get_vLLM_output(
                vllm_runner,
                vllm_runner_kwargs,
                generated_prompts,
                max_tokens,
                num_logprobs,
                n_repetitions,
                vllm_model=vllm_model,
            )

        # Check alignment of the output logits when using APC
        for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
            # In the first repetition, the caches are filled
            # In the second repetition, these caches are reused
            compare_operator(
                outputs_0_lst=vllm_outputs_no_cache[0],
                outputs_1_lst=vllm_outputs_cache_itn,
                name_0="vllm_no_cache",
                name_1=f"vllm_cache_it_{r_idx + 1}",
            )


@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_multiple_prompts_partial_cached_outputs(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    n_repetitions: int,
    num_logprobs: int,
    tensor_parallel_size: int,
) -> None:
    """Prefix caching must stay correct when only some prompts were cached.

    After an uncached baseline over all long prompts, a prefix-caching engine
    first processes only the first three prompts and is checked against the
    baseline. The same engine then runs the full batch ``n_repetitions``
    times, and each run is compared with the baseline.
    """
    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
        model_info.check_available_online(on_fail="skip")
        model_info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    compare_operator: Callable = (
        check_logprobs_close if num_logprobs > 0 else check_outputs_equal  # type: ignore
    )

    # Sample prompts.
    generated_prompts = [APC_MULTIPLY_BY * prompt for prompt in example_prompts]

    # NOTE: len(prompt) counts characters, not tokens.
    max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
    vllm_runner_kwargs = _get_vllm_runner_params(
        model, max_model_len, tensor_parallel_size=tensor_parallel_size
    )
    vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32"

    # Each engine lives in its own `with` block, so it is exited before the
    # next engine is created instead of being left open.
    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        vllm_outputs_no_cache, _ = _get_vLLM_output(
            vllm_runner,
            vllm_runner_kwargs,
            generated_prompts,
            max_tokens,
            num_logprobs,
            vllm_model=vllm_model,
        )

    # Cache only part of all the prompts
    vllm_runner_kwargs["enable_prefix_caching"] = True
    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        vllm_outputs_partial_cache, _ = _get_vLLM_output(
            vllm_runner,
            vllm_runner_kwargs,
            generated_prompts[:3],
            max_tokens,
            num_logprobs,
            vllm_model=vllm_model,
        )

        compare_operator(
            outputs_0_lst=vllm_outputs_no_cache[0][:3],
            outputs_1_lst=vllm_outputs_partial_cache[0],
            name_0="vllm_no_cache",
            name_1="vllm_partial_cache",
        )

        # Reuse the engine that has already processed the first three
        # prompts, then run the whole batch.
        vllm_outputs_cache_rep, _ = _get_vLLM_output(
            vllm_runner,
            vllm_runner_kwargs,
            generated_prompts,
            max_tokens,
            num_logprobs,
            n_repetitions,
            vllm_model=vllm_model,
        )

    for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
        # In the first repetition, the caches are filled
        # In the second repetition, these caches are reused
        compare_operator(
            outputs_0_lst=vllm_outputs_no_cache[0],
            outputs_1_lst=vllm_outputs_cache_itn,
            name_0="vllm_no_cache",
            name_1=f"vllm_cache_it_{r_idx + 1}",
        )