test_hybrid.py 27.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Callable
5

Mor Zusman's avatar
Mor Zusman committed
6
7
import pytest

8
from tests.models.registry import HF_EXAMPLE_MODELS
9
from tests.utils import multi_gpu_test
10
from vllm import LLM
11
from vllm.engine.arg_utils import EngineArgs
12
from vllm.platforms import current_platform
13
from vllm.sampling_params import SamplingParams
14
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
15

16
from ...utils import check_logprobs_close, check_outputs_equal
17

18
19
20
# Mark all tests as hybrid
pytestmark = pytest.mark.hybrid_model

# NOTE: The first entry of SSM_MODELS and of HYBRID_MODELS is the "primary"
# model of its list: it is the one exercised by the functional tests below
# (batching, chunked prefill, CUDA-graph padding, state cleanup, TP
# correctness). The APC tests additionally use HYBRID_MODELS[3], and the
# CUDA-graph / FP32-state tests use their own lists. Every entry of
# SSM_MODELS + HYBRID_MODELS is covered by test_models.
# Several tests index these lists by position, so reordering them changes
# which models those tests run.

# Factor by which the APC tests repeat their base prompts to build long inputs.
APC_MULTIPLY_BY = 300

SSM_MODELS = [
    "state-spaces/mamba-130m-hf",
    "tiiuae/falcon-mamba-tiny-dev",
    # mamba2-codestral in transformers is broken pending:
    # https://github.com/huggingface/transformers/pull/40861
    # "yujiepan/mamba2-codestral-v0.1-tiny-random",
]

HYBRID_MODELS = [
    "ai21labs/Jamba-tiny-dev",
    "pfnet/plamo-2-1b",
    "Zyphra/Zamba2-1.2B-instruct",
    "hmellor/tiny-random-BambaForCausalLM",
    "ibm-granite/granite-4.0-tiny-preview",
    "tiiuae/Falcon-H1-0.5B-Base",
    "LiquidAI/LFM2-1.2B",
    "tiny-random/qwen3-next-moe",
]

FULL_CUDA_GRAPH_MODELS = [
    "ai21labs/Jamba-tiny-dev",
    "pfnet/plamo-2-1b",
    "Zyphra/Zamba2-1.2B-instruct",
]

FP32_STATE_MODELS = [
    "state-spaces/mamba-130m-hf",
    "Zyphra/Zamba2-1.2B-instruct",
]

# Cap concurrent sequences to avoid OOM
MAX_NUM_SEQS = 4

# On ROCm force the Triton attention backend; elsewhere pass "auto"
# (presumably letting vLLM choose the backend).
ATTN_BACKEND = "TRITON_ATTN" if current_platform.is_rocm() else "auto"

Mor Zusman's avatar
Mor Zusman committed
62

63
64
65
@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Compare greedy top-``num_logprobs`` outputs of HF transformers and vLLM.

    If the example-model registry knows ``model``, the test is skipped when
    the model is unavailable online or needs another transformers version.
    A ``ValueError`` from the registry lookup (presumably: model not
    registered) just bypasses those checks.
    """
    try:
        info = HF_EXAMPLE_MODELS.find_hf_info(model)
        info.check_available_online(on_fail="skip")
        info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    with hf_runner(model) as hf_model:
        reference = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs
        )

    engine_kwargs = {"max_num_seqs": MAX_NUM_SEQS, "attention_backend": ATTN_BACKEND}
    with vllm_runner(model, **engine_kwargs) as vllm_model:
        candidate = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs
        )

    check_logprobs_close(
        outputs_0_lst=reference,
        outputs_1_lst=candidate,
        name_0="hf",
        name_1="vllm",
    )
Mor Zusman's avatar
Mor Zusman committed
100
101


102
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_batching(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Batched generation must agree with generating each prompt on its own.

    Both passes run on the same engine. The per-prompt results are the
    reference, and the batched results are checked against them for logprob
    closeness.
    """
    try:
        info = HF_EXAMPLE_MODELS.find_hf_info(model)
        info.check_available_online(on_fail="skip")
        info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
        one_at_a_time = []
        for prompt in example_prompts:
            # A single-prompt request must yield exactly one result.
            [result] = vllm_model.generate_greedy_logprobs(
                [prompt], max_tokens, num_logprobs
            )
            one_at_a_time.append(result)

        all_at_once = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs
        )

    check_logprobs_close(
        outputs_0_lst=one_at_a_time,
        outputs_1_lst=all_at_once,
        name_0="for_loop_vllm",
        name_1="batched_vllm",
    )


139
140
141
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [10])
def test_chunked_prefill_with_parallel_sampling(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
) -> None:
    """
    Run chunked prefill together with parallel sampling (n > 1).

    The token budget is small enough that prefill chunks get scheduled
    together with decode tokens; the test passes as long as generation does
    not fail.

    A failure here may mean the cache is not allocated correctly for n > 1
    decoding steps inside a chunked-prefill forward pass (one that mixes
    prefill and decode).
    """
    sampling_params = SamplingParams(n=3, temperature=1, seed=0, max_tokens=max_tokens)
    engine_kwargs = {
        "enable_chunked_prefill": True,
        # forces prefill chunks with decoding
        "max_num_batched_tokens": MAX_NUM_SEQS * 3,
        "max_num_seqs": MAX_NUM_SEQS,
        "attention_backend": ATTN_BACKEND,
    }
    with vllm_runner(model, **engine_kwargs) as vllm_model:
        vllm_model.generate(example_prompts, sampling_params)
167
168


169
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [20])
def test_mamba_cache_cg_padding(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
) -> None:
    """
    This test is for verifying that mamba cache is padded to CG captured
    batch size. If it's not, a torch RuntimeError will be raised because
    tensor dimensions aren't compatible.
    """
    vllm_config = EngineArgs(model=model, trust_remote_code=True).create_engine_config()
    cudagraph_dispatcher = CudagraphDispatcher(vllm_config)
    cudagraph_dispatcher.initialize_cudagraph_keys(
        vllm_config.compilation_config.cudagraph_mode
    )

    # Pad a local copy rather than appending to the fixture's list in place.
    prompts = list(example_prompts)
    # Grow the batch while its size equals the dispatcher's num_tokens for it,
    # i.e. stop at a size the dispatcher maps to a different (presumably
    # padded, captured) size.
    while len(prompts) == cudagraph_dispatcher.dispatch(len(prompts))[1].num_tokens:
        prompts.append(prompts[0])

    try:
        with vllm_runner(model) as vllm_model:
            vllm_model.generate_greedy(prompts, max_tokens)
    except RuntimeError:
        pytest.fail(
            "Couldn't run batch size which is not equal to a Cuda Graph "
            "captured batch size. "
            "Could be related to mamba cache not padded correctly"
        )
202
203


204
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
    vllm_runner,
    example_prompts,
    model: str,
) -> None:
    """
    Verify that the hybrid inner state management doesn't collapse when the
    number of incoming requests and finished_requests_ids is larger than the
    maximum mamba block capacity.

    This could generally happen due to the fact that hybrid does support a
    statelessness mechanism where it can clean up new incoming requests in a
    single step.
    """
    try:
        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
            # 100 requests, far more than MAX_NUM_SEQS concurrent sequences.
            vllm_model.generate_greedy([example_prompts[0]] * 100, 10)
    except ValueError:
        # The two literals are implicitly concatenated, so the separating
        # space must live at the end of the first one.
        pytest.fail(
            "Hybrid inner state wasn't cleaned up properly between "
            "steps finished requests registered unnecessarily"
        )
227
228


229
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
def test_state_cleanup(
    vllm_runner,
    example_prompts,
    model: str,
) -> None:
    """
    Verify that the hybrid state is cleaned up between steps.

    Ten consecutive rounds of 100 one-token requests run on a single engine.
    State left behind from an earlier round is expected to surface as an
    error (caught here as ValueError).
    """
    failure_message = (
        "Hybrid inner state wasn't cleaned up between states, "
        "could be related to finished_requests_ids"
    )
    try:
        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
            for _round in range(10):
                vllm_model.generate_greedy([example_prompts[0]] * 100, 1)
    except ValueError:
        pytest.fail(failure_message)
Mor Zusman's avatar
Mor Zusman committed
250
251


252
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_distributed_correctness(
    vllm_runner,
    example_prompts,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Greedy logprobs with tensor parallel size 2 must match TP size 1."""
    outputs_by_tp = {}
    for tp_size in (1, 2):
        with vllm_runner(
            model, tensor_parallel_size=tp_size, max_num_seqs=MAX_NUM_SEQS
        ) as vllm_model:
            outputs_by_tp[tp_size] = vllm_model.generate_greedy_logprobs(
                example_prompts, max_tokens, num_logprobs
            )

    check_logprobs_close(
        outputs_0_lst=outputs_by_tp[1],
        outputs_1_lst=outputs_by_tp[2],
        name_0="vllm_tp_1",
        name_1="vllm_tp_2",
    )
283
284


285
@pytest.mark.parametrize("model", FULL_CUDA_GRAPH_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_full_cuda_graph(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Compare HF and vLLM greedy logprobs for FULL_CUDA_GRAPH_MODELS.

    NOTE(review): no CUDA-graph option is passed to ``vllm_runner`` here, so
    this relies on the engine's default compilation config capturing full
    CUDA graphs for these models. Confirm that default.
    """
    try:
        registry_entry = HF_EXAMPLE_MODELS.find_hf_info(model)
        registry_entry.check_available_online(on_fail="skip")
        registry_entry.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    with hf_runner(model) as hf_model:
        expected = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs
        )

    with vllm_runner(
        model, attention_backend=ATTN_BACKEND, max_num_seqs=MAX_NUM_SEQS
    ) as vllm_model:
        actual = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs
        )

    check_logprobs_close(
        outputs_0_lst=expected,
        outputs_1_lst=actual,
        name_0="hf",
        name_1="vllm",
    )
322
323


324
@pytest.mark.parametrize("model", FP32_STATE_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize(
    "cache_dtype_param", ["mamba_ssm_cache_dtype", "mamba_cache_dtype"]
)
def test_fp32_cache_state(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    num_logprobs: int,
    cache_dtype_param: str,
) -> None:
    """Compare HF and vLLM with a mamba cache dtype forced to float32.

    ``cache_dtype_param`` is the name of the engine argument
    (``mamba_ssm_cache_dtype`` or ``mamba_cache_dtype``) set to ``"float32"``.
    """
    try:
        registry_entry = HF_EXAMPLE_MODELS.find_hf_info(model)
        registry_entry.check_available_online(on_fail="skip")
        registry_entry.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    with hf_runner(model) as hf_model:
        expected = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs
        )

    engine_kwargs = {"max_num_seqs": MAX_NUM_SEQS, cache_dtype_param: "float32"}
    with vllm_runner(model, **engine_kwargs) as vllm_model:
        actual = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs
        )

    check_logprobs_close(
        outputs_0_lst=expected,
        outputs_1_lst=actual,
        name_0="hf",
        name_1="vllm",
    )
365
366
367


# Helper functions for the APC tests
368
369
370
371
372
def _get_vllm_runner_params(
    model: str,
    max_model_len: int,
    tensor_parallel_size: int = 1,
):
    """Build the baseline ``vllm_runner`` keyword arguments for the APC tests.

    Chunked prefill is on and prefix caching is off. A new dict is returned on
    every call, so callers may mutate it (e.g. flip ``enable_prefix_caching``).
    """
    params = dict(
        model_name=model,
        enable_chunked_prefill=True,
        enable_prefix_caching=False,
        max_model_len=max_model_len,
        tensor_parallel_size=tensor_parallel_size,
        gpu_memory_utilization=0.4,
        attention_backend=ATTN_BACKEND,
    )
    return params


384
385
386
387
388
389
390
391
392
def _get_vLLM_output(
    vllm_runner,
    kwargs,
    prompts,
    max_tokens,
    num_logprobs,
    num_repetitions=1,
    vllm_model=None,
):
393
394
395
396
397
398
399
400
    outs = []
    if vllm_model is None:
        vllm_model = vllm_runner(**kwargs)
    for _ in range(num_repetitions):
        if num_logprobs < 0:
            vllm_output = vllm_model.generate_greedy(prompts, max_tokens)
        else:
            vllm_output = vllm_model.generate_greedy_logprobs(
401
402
                prompts, max_tokens, num_logprobs
            )
403
404
405
406
407
        outs.append(vllm_output)

    return outs, vllm_model


408
@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_single_prompt(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    n_repetitions: int,
    num_logprobs: int,
    tensor_parallel_size: int,
) -> None:
    """Prefix caching must not change the output for a single long prompt.

    A reference is generated with prefix caching disabled. The same prompt is
    then generated ``n_repetitions`` times on one prefix-caching engine, and
    each repetition is compared with the reference.
    """
    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
        model_info.check_available_online(on_fail="skip")
        model_info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    compare_operator: Callable = (
        check_logprobs_close if num_logprobs > 0 else check_outputs_equal  # type: ignore
    )

    # Sample prompts.
    generated_prompts = [APC_MULTIPLY_BY * example_prompts[0]]

    max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
    vllm_runner_kwargs = _get_vllm_runner_params(
        model, max_model_len, tensor_parallel_size=tensor_parallel_size
    )
    vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32"

    # Enter each engine as a context manager so it is torn down before the
    # next one is built; `_get_vLLM_output` never exits engines it creates.
    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        vllm_outputs_no_cache, _ = _get_vLLM_output(
            vllm_runner,
            vllm_runner_kwargs,
            generated_prompts,
            max_tokens,
            num_logprobs,
            vllm_model=vllm_model,
        )

    vllm_runner_kwargs["enable_prefix_caching"] = True
    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        vllm_outputs_cache_rep, _ = _get_vLLM_output(
            vllm_runner,
            vllm_runner_kwargs,
            generated_prompts,
            max_tokens,
            num_logprobs,
            n_repetitions,
            vllm_model=vllm_model,
        )

    for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
        # In the first repetition, the caches are filled
        # In the second repetition, these caches are reused
        compare_operator(
            outputs_0_lst=vllm_outputs_no_cache[0],
            outputs_1_lst=vllm_outputs_cache_itn,
            name_0="vllm_no_cache",
            name_1=f"vllm_cache_it_{r_idx + 1}",
        )


472
@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_single_prompt_block_align_alignment(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    n_repetitions: int,
    num_logprobs: int,
    tensor_parallel_size: int,
) -> None:
    """Prefix caching must not change output across several token budgets.

    The reference is produced without prefix caching. With prefix caching
    enabled, the prompt is regenerated under several ``max_num_batched_tokens``
    values near a multiple of the mamba block size.
    """
    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
        model_info.check_available_online(on_fail="skip")
        model_info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    compare_operator: Callable = (
        check_logprobs_close if num_logprobs > 0 else check_outputs_equal  # type: ignore
    )

    # Sample prompts. This custom prompt is used, as it causes the most issues
    generated_prompts = ["The president of the United States is " * APC_MULTIPLY_BY]

    max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
    vllm_runner_kwargs = _get_vllm_runner_params(
        model, max_model_len, tensor_parallel_size=tensor_parallel_size
    )
    vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32"

    # Enter each engine as a context manager so it is torn down before the
    # next one is built; `_get_vLLM_output` never exits engines it creates.
    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        vllm_outputs_no_cache, _ = _get_vLLM_output(
            vllm_runner,
            vllm_runner_kwargs,
            generated_prompts,
            max_tokens,
            num_logprobs,
            vllm_model=vllm_model,
        )

    vllm_runner_kwargs["enable_prefix_caching"] = True
    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        # Retrieve the default mamba state block size
        vllm_config = vllm_model.llm.llm_engine.vllm_config
        mamba_block_size = vllm_config.cache_config.mamba_block_size

    # In case the hybrid model does not have the
    # "mamba_block_size" assume a fixed constant
    if mamba_block_size is None:
        mamba_block_size = 512

    mamba_block_size_multiplier = 10
    # Budgets just above / below a multiple of the mamba block size, presumably
    # so chunk boundaries land at different offsets from mamba block edges.
    for offset in [-3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3]:
        vllm_runner_kwargs["max_num_batched_tokens"] = (
            mamba_block_size_multiplier * mamba_block_size - offset
        )
        with vllm_runner(**vllm_runner_kwargs) as vllm_model:
            vllm_outputs_cache_rep, _ = _get_vLLM_output(
                vllm_runner,
                vllm_runner_kwargs,
                generated_prompts,
                max_tokens,
                num_logprobs,
                n_repetitions,
                vllm_model=vllm_model,
            )

        # Check alignment of the output logits when using APC
        for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
            # In the first repetition, the caches are filled
            # In the second repetition, these caches are reused
            compare_operator(
                outputs_0_lst=vllm_outputs_no_cache[0],
                outputs_1_lst=vllm_outputs_cache_itn,
                name_0="vllm_no_cache",
                name_1=f"vllm_cache_it_{r_idx + 1}",
            )


553
@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_multiple_prompts_all_cached_outputs(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    n_repetitions: int,
    num_logprobs: int,
    tensor_parallel_size: int,
) -> None:
    """Prefix caching must not change outputs for a batch of long prompts.

    The reference is produced without prefix caching. The whole batch is then
    generated ``n_repetitions`` times on one prefix-caching engine.
    """
    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
        model_info.check_available_online(on_fail="skip")
        model_info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    compare_operator: Callable = (
        check_logprobs_close if num_logprobs > 0 else check_outputs_equal  # type: ignore
    )

    # Sample prompts.
    generated_prompts = [APC_MULTIPLY_BY * prompt for prompt in example_prompts]

    max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
    vllm_runner_kwargs = _get_vllm_runner_params(
        model, max_model_len, tensor_parallel_size=tensor_parallel_size
    )
    vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32"
    # Reduce the effects of batch variance on ROCm since batch invariance is not
    # yet supported. See: https://github.com/vllm-project/vllm/issues/27433
    if current_platform.is_rocm():
        vllm_runner_kwargs["max_num_seqs"] = 4

    # Enter each engine as a context manager so it is torn down before the
    # next one is built; `_get_vLLM_output` never exits engines it creates.
    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        vllm_outputs_no_cache, _ = _get_vLLM_output(
            vllm_runner,
            vllm_runner_kwargs,
            generated_prompts,
            max_tokens,
            num_logprobs,
            vllm_model=vllm_model,
        )

    vllm_runner_kwargs["enable_prefix_caching"] = True
    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        vllm_outputs_cache_rep, _ = _get_vLLM_output(
            vllm_runner,
            vllm_runner_kwargs,
            generated_prompts,
            max_tokens,
            num_logprobs,
            n_repetitions,
            vllm_model=vllm_model,
        )

    for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
        # In the first repetition, the caches are filled
        # In the second repetition, these caches are reused
        compare_operator(
            outputs_0_lst=vllm_outputs_no_cache[0],
            outputs_1_lst=vllm_outputs_cache_itn,
            name_0="vllm_no_cache",
            name_1=f"vllm_cache_it_{r_idx + 1}",
        )


622
@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_multiple_prompts_block_align_alignment(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    n_repetitions: int,
    num_logprobs: int,
    tensor_parallel_size: int,
) -> None:
    """Prefix caching must not change outputs for staggered prompts.

    The prompts are suffixes of one sentence, repeated. With prefix caching
    on, the batch is regenerated under several ``max_num_batched_tokens``
    values near a multiple of the mamba block size.
    """
    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
        model_info.check_available_online(on_fail="skip")
        model_info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    compare_operator: Callable = (
        check_logprobs_close if num_logprobs > 0 else check_outputs_equal  # type: ignore
    )

    # Sample prompts. This custom prompt is used, as it causes the most issues
    prompt_text = "The president of the United States is "
    prompt_offsets = [0, 3, 7, 13, 17, 22, 25, 31]
    generated_prompts = [
        prompt_text[offset:] * APC_MULTIPLY_BY for offset in prompt_offsets
    ]

    max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
    vllm_runner_kwargs = _get_vllm_runner_params(
        model, max_model_len, tensor_parallel_size=tensor_parallel_size
    )
    vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32"

    # Enter each engine as a context manager so it is torn down before the
    # next one is built; `_get_vLLM_output` never exits engines it creates.
    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        vllm_outputs_no_cache, _ = _get_vLLM_output(
            vllm_runner,
            vllm_runner_kwargs,
            generated_prompts,
            max_tokens,
            num_logprobs,
            vllm_model=vllm_model,
        )

    vllm_runner_kwargs["enable_prefix_caching"] = True
    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        # Retrieve the default mamba state block size
        vllm_config = vllm_model.llm.llm_engine.vllm_config
        mamba_block_size = vllm_config.cache_config.mamba_block_size

    # In case the hybrid model does not have the
    # "mamba_block_size" assume a fixed constant
    if mamba_block_size is None:
        mamba_block_size = 512

    mamba_block_size_multiplier = 10
    # Budgets just above / below a multiple of the mamba block size, presumably
    # so chunk boundaries land at different offsets from mamba block edges.
    for offset in [-3, 3, mamba_block_size // 4 + 3, mamba_block_size // 2 - 3]:
        vllm_runner_kwargs["max_num_batched_tokens"] = (
            mamba_block_size_multiplier * mamba_block_size - offset
        )
        with vllm_runner(**vllm_runner_kwargs) as vllm_model:
            vllm_outputs_cache_rep, _ = _get_vLLM_output(
                vllm_runner,
                vllm_runner_kwargs,
                generated_prompts,
                max_tokens,
                num_logprobs,
                n_repetitions,
                vllm_model=vllm_model,
            )

        # Check alignment of the output logits when using APC
        for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
            # In the first repetition, the caches are filled
            # In the second repetition, these caches are reused
            compare_operator(
                outputs_0_lst=vllm_outputs_no_cache[0],
                outputs_1_lst=vllm_outputs_cache_itn,
                name_0="vllm_no_cache",
                name_1=f"vllm_cache_it_{r_idx + 1}",
            )


707
@pytest.mark.parametrize("model", [HYBRID_MODELS[0], HYBRID_MODELS[3]])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("n_repetitions", [2])
# If num_logprobs is set to -1, then the stringent version
# of the test is executed using `check_outputs_equal`
# instead of `check_logprobs_close`
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_apc_multiple_prompts_partial_cached_outputs(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    n_repetitions: int,
    num_logprobs: int,
    tensor_parallel_size: int,
) -> None:
    """Prefix caching must not change outputs when only some prompts are cached.

    On one prefix-caching engine, the first three prompts are generated first.
    The full batch is then generated ``n_repetitions`` times. Both phases are
    compared with a no-prefix-caching reference.
    """
    try:
        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
        model_info.check_available_online(on_fail="skip")
        model_info.check_transformers_version(on_fail="skip")
    except ValueError:
        pass

    compare_operator: Callable = (
        check_logprobs_close if num_logprobs > 0 else check_outputs_equal  # type: ignore
    )

    # Sample prompts.
    generated_prompts = [APC_MULTIPLY_BY * prompt for prompt in example_prompts]

    max_model_len = max(len(prompt) + max_tokens for prompt in generated_prompts)
    vllm_runner_kwargs = _get_vllm_runner_params(
        model, max_model_len, tensor_parallel_size=tensor_parallel_size
    )
    vllm_runner_kwargs["mamba_ssm_cache_dtype"] = "float32"

    # Enter each engine as a context manager so it is torn down before the
    # next one is built; `_get_vLLM_output` never exits engines it creates.
    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        vllm_outputs_no_cache, _ = _get_vLLM_output(
            vllm_runner,
            vllm_runner_kwargs,
            generated_prompts,
            max_tokens,
            num_logprobs,
            vllm_model=vllm_model,
        )

    # Cache only part of all the prompts
    vllm_runner_kwargs["enable_prefix_caching"] = True
    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        vllm_outputs_partial_cache, _ = _get_vLLM_output(
            vllm_runner,
            vllm_runner_kwargs,
            generated_prompts[:3],
            max_tokens,
            num_logprobs,
            vllm_model=vllm_model,
        )

        compare_operator(
            outputs_0_lst=vllm_outputs_no_cache[0][:3],
            outputs_1_lst=vllm_outputs_partial_cache[0],
            name_0="vllm_no_cache",
            name_1="vllm_partial_cache",
        )

        # Same engine, so the prompts cached above are still available.
        vllm_outputs_cache_rep, _ = _get_vLLM_output(
            vllm_runner,
            vllm_runner_kwargs,
            generated_prompts,
            max_tokens,
            num_logprobs,
            n_repetitions,
            vllm_model=vllm_model,
        )

    for r_idx, vllm_outputs_cache_itn in enumerate(vllm_outputs_cache_rep):
        # In the first repetition, the caches are filled
        # In the second repetition, these caches are reused
        compare_operator(
            outputs_0_lst=vllm_outputs_no_cache[0],
            outputs_1_lst=vllm_outputs_cache_itn,
            name_0="vllm_no_cache",
            name_1=f"vllm_cache_it_{r_idx + 1}",
        )
783
784


785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
# Test that outputs match whether prefix caching is enabled or not for mamba.
@pytest.mark.parametrize("model", ["tiiuae/falcon-mamba-7b"])
def test_same_mamba_output_apc_on_vs_off(
    vllm_runner,
    model: str,
) -> None:
    """Greedy logprobs must agree with prefix caching off and on."""
    num_logprobs = 5
    max_tokens = 20
    prompt = "hello what is one plus one what is one plus one what is one plus one the answer is"  # noqa: E501
    prompts = [prompt, prompt]
    max_model_len = max(len(p) for p in prompts) + max_tokens + 64

    base_kwargs = _get_vllm_runner_params(model, max_model_len)
    base_kwargs.update(
        enforce_eager=True, block_size=16, seed=42, gpu_memory_utilization=0.8
    )

    # Run 1: no prefix caching. Run 2: prefix caching, mamba_block_size=16.
    variants = [
        {"enable_prefix_caching": False},
        {"enable_prefix_caching": True, "mamba_block_size": 16},
    ]
    first_outputs = []
    for overrides in variants:
        run_kwargs = {**base_kwargs, **overrides}
        with vllm_runner(**run_kwargs) as vllm_model:
            outs, _ = _get_vLLM_output(
                vllm_runner,
                run_kwargs,
                prompts,
                max_tokens,
                num_logprobs=num_logprobs,
                vllm_model=vllm_model,
            )
        first_outputs.append(outs[0])

    outputs_no_apc, outputs_with_apc = first_outputs
    check_logprobs_close(
        outputs_0_lst=outputs_no_apc,
        outputs_1_lst=outputs_with_apc,
        name_0="vllm_no_apc",
        name_1="vllm_with_apc",
    )


839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
# we have to use a real large model to get reasonable results
# the model can't be a hybrid model as we need block_size 16
@pytest.mark.parametrize("model", ["tiiuae/falcon-mamba-7b"])
def test_apc_common_prefix_same_batch(
    model: str,
    monkeypatch,
) -> None:
    """Two identical prompts in one batch with prefix caching must both answer.

    Sampling is seeded. Each completion is expected to contain "two".
    """
    # Required to put the two requests in the same batch
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    engine_kwargs = dict(
        model=model,
        enforce_eager=True,
        block_size=16,
        mamba_block_size=16,
        enable_prefix_caching=True,
        seed=42,
        attention_backend=ATTN_BACKEND,
    )
    llm = LLM(**engine_kwargs)

    prompt = "hello what is one plus one what is one plus one what is one plus one the answer is"  # noqa: E501
    params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=20)
    results = llm.generate([prompt, prompt], params)
    for result in results:
        assert "two" in result.outputs[0].text