# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Callable
import pytest

from tests.models.registry import HF_EXAMPLE_MODELS
from tests.utils import multi_gpu_test
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher

from ...utils import check_logprobs_close, check_outputs_equal


# Mark all tests as hybrid
pytestmark = pytest.mark.hybrid_model

# Repetition factor used by the automatic prefix caching (APC) tests to turn
# short example prompts into long ones (presumably long enough to span several
# mamba cache blocks — TODO confirm).
APC_MULTIPLY_BY = 300

# NOTE: The first entry of SSM_MODELS and of HYBRID_MODELS is the "primary"
# model of its family: the single-model tests in this file parametrize over
# SSM_MODELS[0] and HYBRID_MODELS[0] only. test_models runs every entry of
# both lists, and the APC tests additionally run HYBRID_MODELS[3], so
# reordering HYBRID_MODELS changes which models those tests cover.
SSM_MODELS = [
    "state-spaces/mamba-130m-hf",
    "tiiuae/falcon-mamba-tiny-dev",
    # mamba2-codestral in transformers is broken pending:
    # https://github.com/huggingface/transformers/pull/40861
    # "yujiepan/mamba2-codestral-v0.1-tiny-random",
]

HYBRID_MODELS = [
    "ai21labs/Jamba-tiny-dev",
    "pfnet/plamo-2-1b",
    "Zyphra/Zamba2-1.2B-instruct",
    "hmellor/tiny-random-BambaForCausalLM",
    "ibm-granite/granite-4.0-tiny-preview",
    "tiiuae/Falcon-H1-0.5B-Base",
    "LiquidAI/LFM2-1.2B",
    "tiny-random/qwen3-next-moe",
]

# Every entry is compared against HF by test_full_cuda_graph.
FULL_CUDA_GRAPH_MODELS = [
    "ai21labs/Jamba-tiny-dev",
    "pfnet/plamo-2-1b",
    "Zyphra/Zamba2-1.2B-instruct",
]

# Every entry is run with a float32 mamba cache dtype by test_fp32_cache_state.
FP32_STATE_MODELS = [
    "state-spaces/mamba-130m-hf",
    "Zyphra/Zamba2-1.2B-instruct",
]

# Passed as max_num_seqs to keep the number of concurrent sequences small and
# avoid OOM.
MAX_NUM_SEQS = 4


@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Check that vLLM greedy logprobs stay close to the HF reference.

    Every model in SSM_MODELS and HYBRID_MODELS is run through HF first and
    then through vLLM (capped at MAX_NUM_SEQS concurrent sequences).
    """
    try:
        registry_entry = HF_EXAMPLE_MODELS.find_hf_info(model)
        registry_entry.check_available_online(on_fail="skip")
        registry_entry.check_transformers_version(on_fail="skip")
    except ValueError:
        # Presumably raised for models without a registry entry; such models
        # are still tested.
        pass

    with hf_runner(model) as hf_model:
        reference = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs
        )

    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
        candidate = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs
        )

    check_logprobs_close(
        outputs_0_lst=reference,
        outputs_1_lst=candidate,
        name_0="hf",
        name_1="vllm",
    )


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_batching(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Check that batched generation matches running each prompt alone.

    Both passes share one engine: the prompts are first submitted one at a
    time, then all together, and the greedy logprobs are compared.
    """
    try:
        registry_entry = HF_EXAMPLE_MODELS.find_hf_info(model)
        registry_entry.check_available_online(on_fail="skip")
        registry_entry.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:

        def run_alone(prompt):
            # Tuple unpacking fails unless exactly one output is returned.
            (only_output,) = vllm_model.generate_greedy_logprobs(
                [prompt], max_tokens, num_logprobs
            )
            return only_output

        one_by_one = [run_alone(prompt) for prompt in example_prompts]
        all_at_once = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs
        )

    check_logprobs_close(
        outputs_0_lst=one_by_one,
        outputs_1_lst=all_at_once,
        name_0="for_loop_vllm",
        name_1="batched_vllm",
    )


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [10])
def test_chunked_prefill_with_parallel_sampling(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
) -> None:
    """
    Smoke-test chunked prefill combined with parallel sampling (n > 1).

    The tiny token budget makes prefill chunks share a forward pass with
    decoding tokens. Only successful completion is checked. A failure here
    may mean the cache is not allocated correctly for n > 1 decoding steps
    inside a mixed prefill/decode forward pass.
    """
    params = SamplingParams(n=3, temperature=1, seed=0, max_tokens=max_tokens)
    runner_kwargs = dict(
        enable_chunked_prefill=True,
        # forces prefill chunks with decoding
        max_num_batched_tokens=MAX_NUM_SEQS * 3,
        max_num_seqs=MAX_NUM_SEQS,
    )
    with vllm_runner(model, **runner_kwargs) as vllm_model:
        vllm_model.generate(example_prompts, params)


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [20])
def test_mamba_cache_cg_padding(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
) -> None:
    """
    Verify that the mamba cache is padded to the CUDA-graph captured batch size.

    The batch is grown until its size is not returned unchanged by the
    cudagraph dispatcher, so the run has to pad. If the mamba cache is not
    padded, a torch RuntimeError is raised because tensor dimensions are
    incompatible.
    """
    vllm_config = EngineArgs(model=model, trust_remote_code=True).create_engine_config()
    cudagraph_dispatcher = CudagraphDispatcher(vllm_config)
    cudagraph_dispatcher.initialize_cudagraph_keys(
        vllm_config.compilation_config.cudagraph_mode
    )

    # Work on a copy so the fixture-provided list is never mutated in place.
    prompts = list(example_prompts)
    # dispatch(...)[1].num_tokens is presumably the padded batch size; keep
    # growing the batch while it equals an exactly-captured size.
    # NOTE(review): if dispatch() never pads (e.g. CUDA graphs disabled), this
    # loop would not terminate — confirm that cannot happen on CI platforms.
    while len(prompts) == cudagraph_dispatcher.dispatch(len(prompts))[1].num_tokens:
        prompts.append(prompts[0])

    try:
        with vllm_runner(model) as vllm_model:
            vllm_model.generate_greedy(prompts, max_tokens)
    except RuntimeError:
        pytest.fail(
            "Couldn't run batch size which is not equal to a Cuda Graph "
            "captured batch size. "
            "Could be related to mamba cache not padded correctly"
        )


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
    vllm_runner,
    example_prompts,
    model: str,
) -> None:
    """
    Verify that the hybrid inner state management doesn't collapse when the
    number of incoming requests and finished_requests_ids is larger than the
    maximum mamba block capacity.

    This can happen because the hybrid model supports a statelessness
    mechanism that can clean up new incoming requests within a single step.
    Here 100 identical requests are pushed through an engine capped at
    MAX_NUM_SEQS concurrent sequences.
    """
    try:
        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
            vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
    except ValueError:
        # The two literals are concatenated implicitly, so the separating
        # space must live inside the first one.
        pytest.fail(
            "Hybrid inner state wasn't cleaned up properly between "
            "steps; finished requests registered unnecessarily"
        )


@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
def test_state_cleanup(
    vllm_runner,
    example_prompts,
    model: str,
) -> None:
    """
    Verify that the hybrid state is cleaned up between steps.

    Ten consecutive rounds of 100 identical one-token requests are sent to a
    single engine. State left over from a previous round is expected to
    surface as a ValueError.
    """
    num_rounds = 10
    try:
        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
            for _round in range(num_rounds):
                vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
    except ValueError:
        pytest.fail(
            "Hybrid inner state wasn't cleaned up between states, "
            "could be related to finished_requests_ids"
        )


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_distributed_correctness(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Check that greedy logprobs agree between tensor-parallel sizes 1 and 2."""
    outputs_per_tp_size = {}
    # TP=1 runs first and TP=2 second; each runner's `with` block exits before
    # the next runner is created.
    for tp_size in (1, 2):
        with vllm_runner(
            model, tensor_parallel_size=tp_size, max_num_seqs=MAX_NUM_SEQS
        ) as vllm_model:
            outputs_per_tp_size[tp_size] = vllm_model.generate_greedy_logprobs(
                example_prompts, max_tokens, num_logprobs
            )

    check_logprobs_close(
        outputs_0_lst=outputs_per_tp_size[1],
        outputs_1_lst=outputs_per_tp_size[2],
        name_0="vllm_tp_1",
        name_1="vllm_tp_2",
    )


@pytest.mark.parametrize("model", FULL_CUDA_GRAPH_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_full_cuda_graph(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Compare vLLM greedy logprobs with HF for FULL_CUDA_GRAPH_MODELS.

    NOTE(review): no CUDA-graph option is passed to vllm_runner here, so this
    relies on the engine's default cudagraph mode — confirm the default
    actually exercises full CUDA graphs.
    """
    try:
        registry_entry = HF_EXAMPLE_MODELS.find_hf_info(model)
        registry_entry.check_available_online(on_fail="skip")
        registry_entry.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    with hf_runner(model) as hf_model:
        hf_result = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs
        )

    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
        vllm_result = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs
        )

    check_logprobs_close(
        outputs_0_lst=hf_result,
        outputs_1_lst=vllm_result,
        name_0="hf",
        name_1="vllm",
    )


@pytest.mark.parametrize("model", FP32_STATE_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize(
    "cache_dtype_param", ["mamba_ssm_cache_dtype", "mamba_cache_dtype"]
)
def test_fp32_cache_state(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    num_logprobs: int,
    cache_dtype_param: str,
) -> None:
    """Check HF parity when a mamba cache dtype option is forced to float32.

    ``cache_dtype_param`` names the vllm_runner keyword that is set to
    "float32" (either the SSM-specific or the general mamba cache dtype).
    """
    try:
        registry_entry = HF_EXAMPLE_MODELS.find_hf_info(model)
        registry_entry.check_available_online(on_fail="skip")
        registry_entry.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    with hf_runner(model) as hf_model:
        hf_result = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs
        )

    runner_kwargs = {"max_num_seqs": MAX_NUM_SEQS, cache_dtype_param: "float32"}
    with vllm_runner(model, **runner_kwargs) as vllm_model:
        vllm_result = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs
        )

    check_logprobs_close(
        outputs_0_lst=hf_result,
        outputs_1_lst=vllm_result,
        name_0="hf",
        name_1="vllm",
    )


# Helper functions for the APC tests


def _get_vllm_runner_params(
    model: str,
    max_model_len: int,
    tensor_parallel_size: int = 1,
):
    """Build the baseline ``vllm_runner`` keyword arguments for the APC tests.

    A fresh dict is returned on every call. Chunked prefill is always on.
    Prefix caching starts off, and callers flip ``enable_prefix_caching`` on
    the returned dict for the cached run. The GPU memory fraction is fixed at
    0.4 (presumably so that more than one engine can be resident at once —
    TODO confirm).
    """
    params = dict(
        model_name=model,
        enable_chunked_prefill=True,
        enable_prefix_caching=False,
    )
    params.update(
        max_model_len=max_model_len,
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=0.4,
    )
    return params


def _get_vLLM_output(
    vllm_runner,
    kwargs,
    prompts,
    max_tokens,
    num_logprobs,
    num_repetitions=1,
    vllm_model=None,
):
    """Run greedy generation ``num_repetitions`` times on a single engine.

    Args:
        vllm_runner: Factory called as ``vllm_runner(**kwargs)`` when no
            ``vllm_model`` is supplied.
        kwargs: Keyword arguments for that factory call.
        prompts: Prompts submitted on every repetition.
        max_tokens: Generation length per prompt.
        num_logprobs: A negative value selects plain ``generate_greedy``.
            Otherwise ``generate_greedy_logprobs`` is called with this many
            logprobs.
        num_repetitions: How many times the same prompts are generated.
        vllm_model: Existing engine to reuse instead of creating one.

    Returns:
        ``(outputs, model)``, where ``outputs`` holds one generation result per
        repetition and ``model`` is the engine used, so callers can reuse it.

    NOTE(review): the runner is constructed directly rather than entered with
    ``with``, so it is never exited here. Any cleanup done on exit presumably
    relies on the caller or on garbage collection — confirm.
    """
    model = vllm_runner(**kwargs) if vllm_model is None else vllm_model

    def generate_once():
        if num_logprobs < 0:
            return model.generate_greedy(prompts, max_tokens)
        return model.generate_greedy_logprobs(prompts, max_tokens, num_logprobs)

    outputs = [generate_once() for _ in range(num_repetitions)]
    return outputs, model


@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is negative (e.g. -1), then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_single_prompt(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    n_repetitions: int,
    num_logprobs: int,
    tensor_parallel_size: int,
) -> None:
    """Check that prefix caching does not change the output for one long prompt.

    The prompt is generated once with prefix caching disabled. It is then
    generated ``n_repetitions`` times on a single engine with prefix caching
    enabled, and every cached run is compared with the uncached one.
    """
    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
        model_info.check_available_online(on_fail="skip")
        model_info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    # Use the same predicate _get_vLLM_output uses to pick generate_greedy
    # (num_logprobs < 0) over generate_greedy_logprobs. With the previous
    # `> 0` check, num_logprobs == 0 paired logprob outputs with
    # check_outputs_equal.
    compare_operator: Callable = (
        check_outputs_equal if num_logprobs < 0 else check_logprobs_close  # type: ignore
    )

    # Sample prompts.
    generated_prompts = [APC_MULTIPLY_BY * example_prompts[0]]

    # len(prompt) counts characters, presumably an over-estimate of tokens.
    max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
    vllm_runner_kwargs = _get_vllm_runner_params(
        model, max_model_len, tensor_parallel_size=tensor_parallel_size
    )
    vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32"
    vllm_outputs_no_cache, _ = _get_vLLM_output(
        vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs
    )

    vllm_runner_kwargs["enable_prefix_caching"] = True
    vllm_outputs_cache_rep, _ = _get_vLLM_output(
        vllm_runner,
        vllm_runner_kwargs,
        generated_prompts,
        max_tokens,
        num_logprobs,
        n_repetitions,
    )

    for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
        # In the first repetition, the caches are filled
        # In the second repetition, these caches are reused
        compare_operator(
            outputs_0_lst=vllm_outputs_no_cache[0],
            outputs_1_lst=vllm_outputs_cache_itn,
            name_0="vllm_no_cache",
            name_1=f"vllm_cache_it_{r_idx + 1}",
        )


@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is negative (e.g. -1), then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_single_prompt_block_align_alignment(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    n_repetitions: int,
    num_logprobs: int,
    tensor_parallel_size: int,
) -> None:
    """Check that prefix caching matches uncached output across token budgets.

    ``max_num_batched_tokens`` is set to 10 mamba blocks shifted by several
    small offsets, so that chunked-prefill chunks (presumably) fall at
    different positions relative to mamba block boundaries. For each budget,
    ``n_repetitions`` cached runs are compared with one uncached run.
    """
    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
        model_info.check_available_online(on_fail="skip")
        model_info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    # Use the same predicate _get_vLLM_output uses to pick generate_greedy
    # (num_logprobs < 0) over generate_greedy_logprobs, so num_logprobs == 0
    # no longer pairs logprob outputs with check_outputs_equal.
    compare_operator: Callable = (
        check_outputs_equal if num_logprobs < 0 else check_logprobs_close  # type: ignore
    )

    # Sample prompts. This custom prompt is used, as it causes the most issues
    generated_prompts = ["The president of the United States is " * APC_MULTIPLY_BY]

    # len(prompt) counts characters, presumably an over-estimate of tokens.
    max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
    vllm_runner_kwargs = _get_vllm_runner_params(
        model, max_model_len, tensor_parallel_size=tensor_parallel_size
    )
    vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32"

    vllm_outputs_no_cache, _ = _get_vLLM_output(
        vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs
    )

    vllm_runner_kwargs["enable_prefix_caching"] = True
    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        # Retrieve the default mamba state block size
        mamba_block_size = vllm_model.llm.llm_engine.cache_config.mamba_block_size

    # In case the hybrid model does not have the
    # "mamba_block_size" assume a fixed constant
    if mamba_block_size is None:
        mamba_block_size = 512

    mamba_block_size_multiplier = 10
    for offset in [-3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3]:
        vllm_runner_kwargs["max_num_batched_tokens"] = (
            mamba_block_size_multiplier * mamba_block_size - offset
        )
        vllm_outputs_cache_rep, _ = _get_vLLM_output(
            vllm_runner,
            vllm_runner_kwargs,
            generated_prompts,
            max_tokens,
            num_logprobs,
            n_repetitions,
        )

        # Check alignment of the output logits when using APC
        for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
            # In the first repetition, the caches are filled
            # In the second repetition, these caches are reused
            compare_operator(
                outputs_0_lst=vllm_outputs_no_cache[0],
                outputs_1_lst=vllm_outputs_cache_itn,
                name_0="vllm_no_cache",
                name_1=f"vllm_cache_it_{r_idx + 1}",
            )


@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is negative (e.g. -1), then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_multiple_prompts_all_cached_outputs(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    n_repetitions: int,
    num_logprobs: int,
    tensor_parallel_size: int,
) -> None:
    """Check that prefix caching does not change output for a batch of prompts.

    Every example prompt is lengthened by APC_MULTIPLY_BY. The whole batch is
    generated once without prefix caching, then ``n_repetitions`` times on a
    single engine with prefix caching enabled.
    """
    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
        model_info.check_available_online(on_fail="skip")
        model_info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    # Use the same predicate _get_vLLM_output uses to pick generate_greedy
    # (num_logprobs < 0) over generate_greedy_logprobs, so num_logprobs == 0
    # no longer pairs logprob outputs with check_outputs_equal.
    compare_operator: Callable = (
        check_outputs_equal if num_logprobs < 0 else check_logprobs_close  # type: ignore
    )

    # Sample prompts.
    generated_prompts = [APC_MULTIPLY_BY * prompt for prompt in example_prompts]

    # len(prompt) counts characters, presumably an over-estimate of tokens.
    max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
    vllm_runner_kwargs = _get_vllm_runner_params(
        model, max_model_len, tensor_parallel_size=tensor_parallel_size
    )
    vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32"
    # Reduce the effects of batch variance on ROCm since batch invariance is not
    # yet supported. See: https://github.com/vllm-project/vllm/issues/27433
    if current_platform.is_rocm():
        vllm_runner_kwargs["max_num_seqs"] = 4

    vllm_outputs_no_cache, _ = _get_vLLM_output(
        vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs
    )

    vllm_runner_kwargs["enable_prefix_caching"] = True
    vllm_outputs_cache_rep, _ = _get_vLLM_output(
        vllm_runner,
        vllm_runner_kwargs,
        generated_prompts,
        max_tokens,
        num_logprobs,
        n_repetitions,
    )

    for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
        # In the first repetition, the caches are filled
        # In the second repetition, these caches are reused
        compare_operator(
            outputs_0_lst=vllm_outputs_no_cache[0],
            outputs_1_lst=vllm_outputs_cache_itn,
            name_0="vllm_no_cache",
            name_1=f"vllm_cache_it_{r_idx + 1}",
        )


@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is negative (e.g. -1), then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_multiple_prompts_block_align_alignment(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    n_repetitions: int,
    num_logprobs: int,
    tensor_parallel_size: int,
) -> None:
    """Check batched prefix caching against uncached output across token budgets.

    The prompts are suffixes of one sentence, starting at several character
    offsets and repeated APC_MULTIPLY_BY times. ``max_num_batched_tokens`` is
    set to 10 mamba blocks shifted by several small offsets, so that
    chunked-prefill chunks (presumably) fall at different positions relative
    to mamba block boundaries.
    """
    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
        model_info.check_available_online(on_fail="skip")
        model_info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    # Use the same predicate _get_vLLM_output uses to pick generate_greedy
    # (num_logprobs < 0) over generate_greedy_logprobs, so num_logprobs == 0
    # no longer pairs logprob outputs with check_outputs_equal.
    compare_operator: Callable = (
        check_outputs_equal if num_logprobs < 0 else check_logprobs_close  # type: ignore
    )

    # Sample prompts. This custom prompt is used, as it causes the most issues
    prompt_text = "The president of the United States is "
    prompt_offsets = [0, 3, 7, 13, 17, 22, 25, 31]
    generated_prompts = [
        prompt_text[offset:] * APC_MULTIPLY_BY for offset in prompt_offsets
    ]

    # len(prompt) counts characters, presumably an over-estimate of tokens.
    max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
    vllm_runner_kwargs = _get_vllm_runner_params(
        model, max_model_len, tensor_parallel_size=tensor_parallel_size
    )
    vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32"

    vllm_outputs_no_cache, _ = _get_vLLM_output(
        vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs
    )

    vllm_runner_kwargs["enable_prefix_caching"] = True
    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        # Retrieve the default mamba state block size
        mamba_block_size = vllm_model.llm.llm_engine.cache_config.mamba_block_size

    # In case the hybrid model does not have the
    # "mamba_block_size" assume a fixed constant
    if mamba_block_size is None:
        mamba_block_size = 512

    mamba_block_size_multiplier = 10
    for offset in [-3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3]:
        vllm_runner_kwargs["max_num_batched_tokens"] = (
            mamba_block_size_multiplier * mamba_block_size - offset
        )
        vllm_outputs_cache_rep, _ = _get_vLLM_output(
            vllm_runner,
            vllm_runner_kwargs,
            generated_prompts,
            max_tokens,
            num_logprobs,
            n_repetitions,
        )

        # Check alignment of the output logits when using APC
        for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
            # In the first repetition, the caches are filled
            # In the second repetition, these caches are reused
            compare_operator(
                outputs_0_lst=vllm_outputs_no_cache[0],
                outputs_1_lst=vllm_outputs_cache_itn,
                name_0="vllm_no_cache",
                name_1=f"vllm_cache_it_{r_idx + 1}",
            )


@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is negative (e.g. -1), then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_multiple_prompts_partial_cached_outputs(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    n_repetitions: int,
    num_logprobs: int,
    tensor_parallel_size: int,
) -> None:
    """Check prefix caching when only some prompts were cached beforehand.

    With prefix caching enabled, the first three prompts are generated and
    compared with the uncached output. The full batch is then generated
    ``n_repetitions`` times on that same engine, so part of the batch may hit
    the cache while the rest does not.
    """
    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
        model_info.check_available_online(on_fail="skip")
        model_info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    # Use the same predicate _get_vLLM_output uses to pick generate_greedy
    # (num_logprobs < 0) over generate_greedy_logprobs, so num_logprobs == 0
    # no longer pairs logprob outputs with check_outputs_equal.
    compare_operator: Callable = (
        check_outputs_equal if num_logprobs < 0 else check_logprobs_close  # type: ignore
    )

    # Sample prompts.
    generated_prompts = [APC_MULTIPLY_BY * prompt for prompt in example_prompts]

    # len(prompt) counts characters, presumably an over-estimate of tokens.
    max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
    vllm_runner_kwargs = _get_vllm_runner_params(
        model, max_model_len, tensor_parallel_size=tensor_parallel_size
    )
    vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32"

    vllm_outputs_no_cache, _ = _get_vLLM_output(
        vllm_runner, vllm_runner_kwargs, generated_prompts, max_tokens, num_logprobs
    )

    # Cache only part of all the prompts
    vllm_runner_kwargs["enable_prefix_caching"] = True
    vllm_outputs_partial_cache, vllm_model = _get_vLLM_output(
        vllm_runner, vllm_runner_kwargs, generated_prompts[:3], max_tokens, num_logprobs
    )

    compare_operator(
        outputs_0_lst=vllm_outputs_no_cache[0][:3],
        outputs_1_lst=vllm_outputs_partial_cache[0],
        name_0="vllm_no_cache",
        name_1="vllm_partial_cache",
    )

    # Reuse the same engine so the prefixes cached above are still present.
    vllm_outputs_cache_rep, _ = _get_vLLM_output(
        vllm_runner,
        vllm_runner_kwargs,
        generated_prompts,
        max_tokens,
        num_logprobs,
        n_repetitions,
        vllm_model=vllm_model,
    )

    for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
        # In the first repetition, the caches are filled
        # In the second repetition, these caches are reused
        compare_operator(
            outputs_0_lst=vllm_outputs_no_cache[0],
            outputs_1_lst=vllm_outputs_cache_itn,
            name_0="vllm_no_cache",
            name_1=f"vllm_cache_it_{r_idx + 1}",
        )