test_config.py 21.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import copy
4
from contextlib import nullcontext
5
from unittest.mock import MagicMock, patch
6

7
import pytest
8
from pydantic import ValidationError
9
10

from vllm.compilation.counter import compilation_counter
11
12
13
from vllm.compilation.passes.utility.fix_functionalization import (
    FixFunctionalizationPass,
)
14
15
16
17
18
19
20
from vllm.config import (
    CompilationConfig,
    CUDAGraphMode,
    ParallelConfig,
    SchedulerConfig,
    VllmConfig,
)
21
from vllm.config.compilation import CompilationMode, PassConfig
22
23
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
24
25
26
27
from vllm.utils.torch_utils import (
    _is_torch_equal_or_newer,
    is_torch_equal,
)
28
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
29

30
31
32
# This import automatically registers `torch.ops.silly.attention`
from . import silly_attention  # noqa: F401

33
34

def test_version():
    """Check ``_is_torch_equal_or_newer`` on dev, git and release versions.

    Nightly (``.devYYYYMMDD+cuXXX``) and source (``a0+git…``) builds of 2.8.0
    as well as 2.8.0/2.8.1 releases must compare as at least ``2.8.0.dev``,
    while the older 2.7.1 release must not.
    """
    # Test the version comparison logic using the private function
    assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev")
    assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev")
    assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev")
    assert _is_torch_equal_or_newer("2.8.1", "2.8.0.dev")
    assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")


def test_get_raw_stream_patch():
    """Test that get_raw_stream is patched into builtins on torch 2.9.0/2.9.1.

    On those torch versions the ``builtins.get_raw_stream`` attribute must
    exist and be exactly ``torch._C._cuda_getCurrentRawStream``. Other torch
    versions are not checked by this test.
    """
    import builtins

    # Check if get_raw_stream exists in builtins
    has_patch = hasattr(builtins, "get_raw_stream")

    is_torch_2_9 = is_torch_equal("2.9.0") or is_torch_equal("2.9.1")

    if is_torch_2_9:
        # For torch 2.9.x, the patch should be applied
        assert has_patch, "get_raw_stream should be patched for torch 2.9.x"
        # Verify it's callable (it should be the _cuda_getCurrentRawStream function)
        get_raw_stream = builtins.get_raw_stream  # type: ignore[attr-defined]
        assert callable(get_raw_stream)
        # Verify it's the correct function from torch._C
        from torch._C import _cuda_getCurrentRawStream

        assert get_raw_stream is _cuda_getCurrentRawStream


66
67
68
69
70
71
72
73
74
75
76
77
78
79
def test_copy_pass():
    """A deep-copied FixFunctionalizationPass must carry the same
    compilation settings as the VllmConfig it was constructed from."""
    config = VllmConfig()
    pass_copy = copy.deepcopy(FixFunctionalizationPass(config))

    copied_cc = pass_copy.compilation_config
    source_cc = config.compilation_config
    assert (
        copied_cc.use_inductor_graph_partition
        == source_cc.use_inductor_graph_partition
    )
    assert copied_cc.splitting_ops == source_cc.splitting_ops


80
81
82
83
84
85
86
87
def test_custom_op():
    """Entries in ``custom_ops`` need a '+' or '-' prefix; a bare op name
    is rejected with an "Invalid syntax" ValueError."""
    prefixed_ops = ["+quant_fp8", "-silu_and_mul"]
    CompilationConfig(custom_ops=prefixed_ops)

    unprefixed_ops = ["quant_fp8"]
    with pytest.raises(ValueError, match="Invalid syntax '"):
        CompilationConfig(custom_ops=unprefixed_ops)


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
# NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends
# on the state of the cache directory on the current machine, which
# may be influenced by other tests.
@pytest.mark.parametrize("val", ["1"])
def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
    """With VLLM_DISABLE_COMPILE_CACHE set, loading a model must neither
    update compile-cache entries nor save compiled artifacts."""
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val)

    compilation_config = {
        "cudagraph_mode": CUDAGraphMode.NONE,  # speed things up a bit
    }
    with (
        compilation_counter.expect(
            num_cache_entries_updated=0, num_compiled_artifacts_saved=0
        ),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config=compilation_config,
            gpu_memory_utilization=0.4,
        ) as _,
    ):
        pass


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
@pytest.mark.parametrize(
    "cudagraph_mode,num_cudagraph_captured",
    [
        (CUDAGraphMode.NONE, 0),
        (CUDAGraphMode.FULL_DECODE_ONLY, 1),
        (CUDAGraphMode.PIECEWISE, 13),
        (CUDAGraphMode.FULL_AND_PIECEWISE, 14),
    ],
)
def test_use_cudagraphs(
    vllm_runner, monkeypatch, cudagraph_mode, num_cudagraph_captured
):
    """Load a model with a single capture size (100) and check the counters.

    One graph must be seen. The GPU runner capture is triggered once unless
    cudagraphs are disabled. The number of captured cudagraphs must match the
    per-mode expectation in the parametrization.
    """
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    compilation_config = {
        "cudagraph_capture_sizes": [100],
        "cudagraph_mode": cudagraph_mode,
    }
    num_gpu_runner_capture_triggers = 1 if cudagraph_mode != CUDAGraphMode.NONE else 0
    with (
        compilation_counter.expect(
            num_graphs_seen=1,
            num_gpu_runner_capture_triggers=num_gpu_runner_capture_triggers,
            num_cudagraph_captured=num_cudagraph_captured,
        ),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config=compilation_config,
            gpu_memory_utilization=0.4,
        ) as _,
    ):
        pass


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_stock_torch_compile(vllm_runner, monkeypatch):
    """CompilationMode.STOCK_TORCH_COMPILE must bump the
    ``stock_torch_compile_count`` counter exactly once while loading a model."""
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    with (
        compilation_counter.expect(stock_torch_compile_count=1),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config={"mode": CompilationMode.STOCK_TORCH_COMPILE},
            gpu_memory_utilization=0.4,
        ) as _,
    ):
        pass


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_no_compilation(vllm_runner, monkeypatch):
    """CompilationMode.NONE must load the model without seeing any graphs
    and without any stock torch.compile invocation."""
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    with (
        compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config={"mode": CompilationMode.NONE},
            gpu_memory_utilization=0.4,
        ) as _,
    ):
        pass


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_enforce_eager(vllm_runner, monkeypatch):
    """enforce_eager=True must load the model without seeing any graphs
    and without any stock torch.compile invocation."""
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    with (
        compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m", enforce_eager=True, gpu_memory_utilization=0.4
        ) as _,
    ):
        pass


def test_splitting_ops_dynamic():
    """Check how splitting_ops and cudagraph_mode are resolved for the
    default config and for combinations of inductor graph partition and the
    attention+quant fusion pass (``fuse_attn_quant``)."""
    # Default config
    config = VllmConfig()
    # After VllmConfig initialization the default cudagraph mode resolves to
    # FULL_AND_PIECEWISE, and the default splitting ops include attention.
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
    assert config.compilation_config.splitting_ops_contain_attention()

    # When use_inductor_graph_partition=True
    config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            use_inductor_graph_partition=True,
            splitting_ops=["vllm::unified_attention"],
        )
    )
    # with inductor partition we use splitting_ops directly for
    # partition rules
    assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]

    # When the attn_fusion pass is enabled (without inductor partition).
    config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
            custom_ops=["+quant_fp8"],
            cudagraph_mode=CUDAGraphMode.PIECEWISE,
        )
    )
    # No ops are split out of the graph ...
    assert config.compilation_config.splitting_ops == []
    # ... and the requested PIECEWISE cudagraph mode falls back to FULL.
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL

    # splitting_ops can not contain attention ops when the attn_fusion
    # pass is enabled.
    with pytest.raises(ValidationError):
        VllmConfig(
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
                custom_ops=["+quant_fp8"],
                cudagraph_mode=CUDAGraphMode.PIECEWISE,
                # workaround to obtain the full list of attention ops
                splitting_ops=CompilationConfig()._attention_ops,
            )
        )

    # When both use_inductor_graph_partition and attn_fusion pass enabled.
    config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            use_inductor_graph_partition=True,
            pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
            custom_ops=["+quant_fp8"],
            cudagraph_mode=CUDAGraphMode.PIECEWISE,
        )
    )
    # With inductor graph partition, attn_fusion and splitting_ops
    # work together. Default splitting_ops include attention ops.
    assert config.compilation_config.splitting_ops_contain_attention()
    # fuse_attn_quant is directly supported under
    # use_inductor_graph_partition=True, and cudagraph_mode
    # is unchanged.
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE


def test_moe_splitting_ops_deepep_ht_inductor_partition():
    """DeepEP high-throughput all2all with dp>1 under inductor partition:
    user splitting_ops that already list the MoE ops are kept unchanged."""
    # Inductor partition case: the user-provided splitting_ops already
    # contain the MoE ops needed for DeepEP HT with dp>1, so the resolved
    # list must be preserved as-is (same order, MoE ops not appended again).
    config = VllmConfig(
        parallel_config=ParallelConfig(
            all2all_backend="deepep_high_throughput",
            data_parallel_size=8,
        ),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            use_inductor_graph_partition=True,
            splitting_ops=[
                "vllm::unified_attention",
                "vllm::moe_forward",
                "vllm::moe_forward_shared",
            ],
        ),
    )
    splitting_ops = config.compilation_config.splitting_ops
    assert splitting_ops == [
        "vllm::unified_attention",
        "vllm::moe_forward",
        "vllm::moe_forward_shared",
    ]


def test_should_split():
    """should_split matches a node against splitting_ops given either as an
    op packet name (``ns::op``) or a specific overload (``ns::op.overload``),
    for both aten ops and custom ops."""
    import torch

    from vllm.compilation.partition_rules import should_split

    graph = torch.fx.Graph()
    node = torch.fx.Node(
        graph=graph,
        name="dummy_node",
        op="call_function",
        target=torch.ops.aten.add.default,
        args=(),
        kwargs={},
    )

    # supports OpOverloadPacket
    splitting_ops = ["aten::add"]
    assert should_split(node, splitting_ops)

    # supports OpOverload
    splitting_ops = ["aten::add.default"]
    assert should_split(node, splitting_ops)

    # a different overload of the same packet must not match
    splitting_ops = ["aten::add.Tensor"]
    assert not should_split(node, splitting_ops)

    q, k, v, out = [torch.randn(1)] * 4

    # supports custom ops as OpOverloadPacket
    node = torch.fx.Node(
        graph=graph,
        name="dummy_node",
        op="call_function",
        target=torch.ops.silly.attention,
        args=(q, k, v, out),
        kwargs={},
    )

    splitting_ops = ["silly::attention"]
    assert should_split(node, splitting_ops)

    # supports custom ops as OpOverload
    node = torch.fx.Node(
        graph=graph,
        name="dummy_node",
        op="call_function",
        target=torch.ops.silly.attention.default,
        args=(q, k, v, out),
        kwargs={},
    )

    splitting_ops = ["silly::attention"]
    assert should_split(node, splitting_ops)

    splitting_ops = ["silly::attention.default"]
    assert should_split(node, splitting_ops)


@pytest.mark.skipif(
    not current_platform.support_static_graph_mode(),
    reason="Skip if not cudagraph mode supported",
)
@pytest.mark.parametrize(
    (
        "cudagraph_capture_sizes",
        "max_cudagraph_capture_size",
        "tp_size",
        "enable_sp",
        "max_num_batched_tokens",
        "cudagraph_mode",
        "expected_max_size",
    ),
    [
        (None, None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
        ([1, 2, 4], 4, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
        # explicit max (8) disagrees with the largest listed size (4)
        (
            [1, 2, 4],
            8,
            1,
            False,
            2048,
            CUDAGraphMode.FULL_AND_PIECEWISE,
            ValidationError,
        ),
        ([1, 256], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
        ([], None, 1, False, 2048, CUDAGraphMode.NONE, 0),
        (None, 0, 1, False, 2048, CUDAGraphMode.NONE, 0),
        # truncated to nearest multiple of 8 or 16
        (None, 257, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
        # max from list
        ([1, 2, 4, 15], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 15),
        # filtered out 15 due to SP
        ([1, 2, 4, 15], None, 2, True, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
        # limited by the max_tokens
        ([1, 2, 4, 15], None, 1, False, 8, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
        # the list must contain at least 1 element when cudagraphs are used
        ([], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
        # the max capture size must be >= 1 when cudagraphs are used
        (None, 0, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
    ],
)
def test_cudagraph_sizes_post_init(
    cudagraph_capture_sizes,
    max_cudagraph_capture_size,
    tp_size,
    enable_sp,
    max_num_batched_tokens,
    cudagraph_mode,
    expected_max_size,
):
    """Build an engine config and check the resolved
    ``max_cudagraph_capture_size``.

    ``expected_max_size`` is either the expected integer or the
    ``ValidationError`` class, meaning config creation must fail.
    """
    # expected_max_size doubles as a sentinel: the exception class itself
    # (identity check) means "expect a validation failure".
    ctx = (
        pytest.raises(ValidationError)
        if expected_max_size is ValidationError
        else nullcontext()
    )

    with (
        ctx,
        # presumably lets tensor_parallel_size=tp_size pass the device-count
        # check on machines with fewer GPUs — TODO confirm
        patch("vllm.config.parallel.cuda_device_count_stateless", return_value=tp_size),
    ):
        compilation_config = CompilationConfig(
            cudagraph_capture_sizes=cudagraph_capture_sizes,
            max_cudagraph_capture_size=max_cudagraph_capture_size,
            pass_config=PassConfig(
                enable_sp=enable_sp,
                fuse_norm_quant=True,
                fuse_act_quant=True,
                eliminate_noops=True,
                sp_min_token_num=512 if enable_sp else None,
            ),
            cudagraph_mode=cudagraph_mode,
        )
        engine_args = EngineArgs(
            model="facebook/opt-125m",
            tensor_parallel_size=tp_size,
            # keep max_num_seqs <= max_num_batched_tokens
            max_num_seqs=min(max_num_batched_tokens, 128),
            max_num_batched_tokens=max_num_batched_tokens,
            compilation_config=compilation_config,
        )
        vllm_config = engine_args.create_engine_config()

        assert (
            vllm_config.compilation_config.max_cudagraph_capture_size
            == expected_max_size
        )


def test_cached_compilation_config(default_vllm_config):
    """set_current_vllm_config must replace a previously cached compilation
    config, so code compiled under it honours the new ``custom_ops``."""
    import torch
    from torch._inductor.utils import run_and_get_code

    from vllm.config import get_cached_compilation_config, set_current_vllm_config
    from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
    from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

    dtype = torch.bfloat16
    device = torch.device("cuda:0")
    batch_size, num_qo_heads, head_size = 8, 16, 128

    # access and cache default compilation config
    # default compilation config does not contain +quant_fp8 custom op. If this is
    # used, the generated code would use inductor-generated triton kernel instead
    # of the custom op `torch.ops._C.static_scaled_fp8_quant`.
    get_cached_compilation_config()

    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["+quant_fp8"],
        )
    )

    # set_current_vllm_config should clear cached compilation config and
    # use the new compilation_config in vllm_config
    with set_current_vllm_config(vllm_config):
        query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
        query_quant = torch.compile(query_quant)

        # Scale and input live on the same explicit device.
        q_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
        query = torch.randn(
            batch_size, num_qo_heads * head_size, dtype=dtype, device=device
        )

        _, code = run_and_get_code(query_quant, query, q_scale)

    code = " ".join(code)
    assert "torch.ops._C.static_scaled_fp8_quant.default(" in code


485
486
487
488
489
490
491
492
493
494
495
496
497
def _create_vllm_config_for_validation(
    compilation_config: CompilationConfig,
) -> MagicMock:
    """Return a ``VllmConfig``-spec'd mock for CudagraphDispatcher padding tests.

    The mock carries the given compilation config, a scheduler config with
    ``max_num_seqs=8``, a default ParallelConfig, and ``None`` for the
    speculative and LoRA configs.
    """
    fake_config = MagicMock(spec=VllmConfig)
    fake_config.compilation_config = compilation_config
    fake_config.scheduler_config = SchedulerConfig.default_factory(max_num_seqs=8)
    fake_config.parallel_config = ParallelConfig()
    # Neither speculative decoding nor LoRA is configured.
    fake_config.speculative_config = fake_config.lora_config = None
    return fake_config


def test_compile_sizes_padding_validation():
    """Test that compile_sizes with values that would be padded raises an error."""
    # cudagraph_capture_sizes=[1, 2, 4, 8] means:
    # - size 1 -> padded to 1
    # - size 2 -> padded to 2
    # - size 3 -> padded to 4
    # - size 4 -> padded to 4
    # - size 5 -> padded to 8
    # etc.
    # So compile_sizes=[3] should fail because 3 would be padded to 4, and
    # compile_sizes=[5] should fail because 5 would be padded to 8.
    for padded_size in (3, 5):
        with pytest.raises(ValueError, match="would be padded to"):
            config = CompilationConfig(
                cudagraph_capture_sizes=[1, 2, 4, 8],
                max_cudagraph_capture_size=8,
                compile_sizes=[padded_size],
                cudagraph_mode=CUDAGraphMode.FULL,
            )
            config.post_init_cudagraph_sizes()
            dispatcher = CudagraphDispatcher(
                _create_vllm_config_for_validation(config)
            )
            dispatcher.initialize_cudagraph_keys(CUDAGraphMode.FULL)

    # Sizes that exactly match capture sizes are accepted.
    config = CompilationConfig(
        cudagraph_capture_sizes=[1, 2, 4, 8],
        max_cudagraph_capture_size=8,
        compile_sizes=[1, 2, 4, 8],
        cudagraph_mode=CUDAGraphMode.FULL,
    )
    config.post_init_cudagraph_sizes()
    assert sorted(config.compile_sizes) == [1, 2, 4, 8]
    dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config))
    dispatcher.initialize_cudagraph_keys(CUDAGraphMode.FULL)  # Should not raise

    # The "cudagraph_capture_sizes" placeholder expands to the capture sizes.
    config = CompilationConfig(
        cudagraph_capture_sizes=[1, 2, 4, 8],
        max_cudagraph_capture_size=8,
        compile_sizes=["cudagraph_capture_sizes"],
        cudagraph_mode=CUDAGraphMode.FULL,
    )
    config.post_init_cudagraph_sizes()
    assert sorted(config.compile_sizes) == [1, 2, 4, 8]

    # When cudagraphs are disabled (max_cudagraph_capture_size=0),
    # padding validation should be skipped
    config = CompilationConfig(
        cudagraph_capture_sizes=[],
        max_cudagraph_capture_size=0,
        compile_sizes=[3, 5, 7],  # would be invalid with cudagraphs
    )
    config.post_init_cudagraph_sizes()
    assert sorted(config.compile_sizes) == [3, 5, 7]

    # When cudagraph_mode is NONE but capture_sizes is non-empty,
    # padding validation should still be skipped
    config = CompilationConfig(
        cudagraph_capture_sizes=[1, 2, 4, 8],
        max_cudagraph_capture_size=8,
        cudagraph_mode=CUDAGraphMode.NONE,
        compile_sizes=[3, 5, 7],  # would be invalid if cudagraphs were enabled
    )
    config.post_init_cudagraph_sizes()
    assert sorted(config.compile_sizes) == [3, 5, 7]

    dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config))
    dispatcher.initialize_cudagraph_keys(CUDAGraphMode.NONE)  # Should not raise


@pytest.mark.parametrize(
    "capture_sizes, max_size, num_blocks, expected_sizes, expected_max",
    [
        # Sizes above num_blocks are dropped; max shrinks to the survivor.
        (
            [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
            512,
            200,
            [1, 2, 4, 8, 16, 32, 64, 128],
            128,
        ),
        # Plenty of blocks: nothing to cap.
        ([1, 2, 4, 8, 16], 16, 1000, [1, 2, 4, 8, 16], 16),
        # num_blocks equal to the max is still enough: nothing to cap.
        ([1, 2, 4, 8, 16, 32], 32, 32, [1, 2, 4, 8, 16, 32], 32),
        # Even the smallest size exceeds num_blocks: everything is dropped.
        ([8, 16, 32], 32, 4, [], 0),
        # Non-positive num_blocks leaves the config untouched.
        ([1, 2, 4], 4, 0, [1, 2, 4], 4),
    ],
)
def test_adjust_cudagraph_sizes_for_mamba_cache(
    capture_sizes, max_size, num_blocks, expected_sizes, expected_max
):
    """Capture sizes must be capped so they fit the available Mamba cache
    blocks, keeping the capture-size list and its max in sync.

    See: https://github.com/vllm-project/vllm/issues/34094
    """
    cfg = CompilationConfig(
        cudagraph_capture_sizes=capture_sizes,
        max_cudagraph_capture_size=max_size,
        cudagraph_mode=CUDAGraphMode.NONE,
    )
    cfg.adjust_cudagraph_sizes_for_mamba_cache(num_blocks)

    resulting_sizes = cfg.cudagraph_capture_sizes
    resulting_max = cfg.max_cudagraph_capture_size
    assert resulting_sizes == expected_sizes
    assert resulting_max == expected_max
    # Whenever any size survives, the last (largest) one is the reported max.
    if expected_sizes:
        assert resulting_sizes[-1] == resulting_max