test_config.py 23.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import copy
4
from contextlib import nullcontext
5
from unittest.mock import MagicMock, patch
6

7
import pytest
8
import torch
9
from pydantic import ValidationError
10
11

from vllm.compilation.counter import compilation_counter
12
13
14
from vllm.compilation.passes.utility.fix_functionalization import (
    FixFunctionalizationPass,
)
15
16
17
18
19
20
21
from vllm.config import (
    CompilationConfig,
    CUDAGraphMode,
    ParallelConfig,
    SchedulerConfig,
    VllmConfig,
)
22
from vllm.config.compilation import CompilationMode, PassConfig
23
24
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
25
26
27
28
from vllm.utils.torch_utils import (
    _is_torch_equal_or_newer,
    is_torch_equal,
)
29
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
30

31
32
33
# This import automatically registers `torch.ops.silly.attention`
from . import silly_attention  # noqa: F401

34
35
# Device type string reported by the current platform; the tests below use it
# to place tensors.
DEVICE_TYPE = current_platform.device_type

36
37

def test_version():
    """Exercise `_is_torch_equal_or_newer` against the "2.8.0.dev" target.

    Nightly, source-build and release version strings at or above the target
    must compare as equal-or-newer; an older release must not.
    """
    target = "2.8.0.dev"
    at_or_above_target = (
        "2.8.0.dev20250624+cu128",
        "2.8.0a0+gitc82a174",
        "2.8.0",
        "2.8.1",
    )
    for version in at_or_above_target:
        assert _is_torch_equal_or_newer(version, target)

    assert not _is_torch_equal_or_newer("2.7.1", target)
44
45


46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def test_get_raw_stream_patch():
    """On torch 2.9.0 / 2.9.1, `builtins.get_raw_stream` must be patched in.

    No assertion is made for any other torch version.
    """
    import builtins

    # Record whether the builtin is present before branching on the version.
    patched = hasattr(builtins, "get_raw_stream")

    if not (is_torch_equal("2.9.0") or is_torch_equal("2.9.1")):
        return

    assert patched, "get_raw_stream should be patched for torch 2.9.x"

    raw_stream_fn = builtins.get_raw_stream  # type: ignore[attr-defined]
    assert callable(raw_stream_fn)

    # The patched builtin must be the very function exported by torch._C.
    from torch._C import _cuda_getCurrentRawStream

    assert raw_stream_fn is _cuda_getCurrentRawStream


69
70
71
72
73
74
75
76
77
78
79
80
81
82
def test_copy_pass():
    """A deep-copied FixFunctionalizationPass keeps its compilation config values.

    Compares `use_inductor_graph_partition` and `splitting_ops` on the copy
    against the VllmConfig the original pass was built from.
    """
    vllm_config = VllmConfig()
    original_pass = FixFunctionalizationPass(vllm_config)
    cloned_pass = copy.deepcopy(original_pass)

    actual = cloned_pass.compilation_config
    expected = vllm_config.compilation_config
    assert actual.use_inductor_graph_partition == expected.use_inductor_graph_partition
    assert actual.splitting_ops == expected.splitting_ops


83
84
85
86
87
88
89
90
def test_custom_op():
    """`custom_ops` accepts '+'/'-' prefixed entries and rejects a bare op name."""
    # Prefixed entries construct without error.
    CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])

    # A bare name is reported as invalid syntax.
    with pytest.raises(ValueError, match="Invalid syntax '"):
        CompilationConfig(custom_ops=["quant_fp8"])


91
92
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
# NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends
# on the state of the cache directory on the current machine, which
# may be influenced by other tests.
@pytest.mark.parametrize("val", ["1"])
def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
    """Loading a model with VLLM_DISABLE_COMPILE_CACHE set must not use the cache.

    The compilation counter is expected to record zero updated cache entries
    and zero saved compiled artifacts while the model is loaded.
    """
    # Keep the engine in this process so the counter observes its updates.
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val)

    # Turning cudagraphs off speeds the test up a bit.
    no_cudagraph_config = {"cudagraph_mode": CUDAGraphMode.NONE}
    expected_counts = {
        "num_cache_entries_updated": 0,
        "num_compiled_artifacts_saved": 0,
    }

    with (
        compilation_counter.expect(**expected_counts),
        # The runner is constructed inside the counter context because
        # loading the model is what triggers compilation (if enabled).
        vllm_runner(
            "facebook/opt-125m",
            compilation_config=no_cudagraph_config,
            gpu_memory_utilization=0.4,
        ),
    ):
        pass


119
120
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
@pytest.mark.parametrize(
    "cudagraph_mode,num_cudagraph_captured",
    [
        (CUDAGraphMode.NONE, 0),
        (CUDAGraphMode.FULL_DECODE_ONLY, 1),
        (CUDAGraphMode.PIECEWISE, 13),
        (CUDAGraphMode.FULL_AND_PIECEWISE, 14),
    ],
)
def test_use_cudagraphs(
    vllm_runner, monkeypatch, cudagraph_mode, num_cudagraph_captured
):
    """Check the cudagraph counters recorded for each cudagraph mode.

    With a single capture size of 100, loading the model is expected to see
    one graph, trigger capture once unless the mode is NONE, and capture the
    parametrized number of cudagraphs.
    """
    # Keep the engine in this process so the counter observes its updates.
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    single_size_config = {
        "cudagraph_capture_sizes": [100],
        "cudagraph_mode": cudagraph_mode,
    }
    # 1 for every mode except NONE, where capture is never triggered.
    capture_triggers = int(cudagraph_mode != CUDAGraphMode.NONE)

    with (
        compilation_counter.expect(
            num_graphs_seen=1,
            num_gpu_runner_capture_triggers=capture_triggers,
            num_cudagraph_captured=num_cudagraph_captured,
        ),
        # The runner is constructed inside the counter context because
        # loading the model is what triggers compilation (if enabled).
        vllm_runner(
            "facebook/opt-125m",
            compilation_config=single_size_config,
            gpu_memory_utilization=0.4,
        ),
    ):
        pass
155
156
157
158


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_stock_torch_compile(vllm_runner, monkeypatch):
    """STOCK_TORCH_COMPILE mode is expected to record stock_torch_compile_count=1."""
    # Keep the engine in this process so the counter observes its updates.
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    stock_compile_config = {"mode": CompilationMode.STOCK_TORCH_COMPILE}

    with (
        compilation_counter.expect(stock_torch_compile_count=1),
        # The runner is constructed inside the counter context because
        # loading the model is what triggers compilation (if enabled).
        vllm_runner(
            "facebook/opt-125m",
            compilation_config=stock_compile_config,
            gpu_memory_utilization=0.4,
        ),
    ):
        pass


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_no_compilation(vllm_runner, monkeypatch):
    """CompilationMode.NONE is expected to see no graphs and no stock compiles."""
    # Keep the engine in this process so the counter observes its updates.
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    no_compile_config = {"mode": CompilationMode.NONE}

    with (
        compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0),
        # The runner is constructed inside the counter context because
        # loading the model is what triggers compilation (if enabled).
        vllm_runner(
            "facebook/opt-125m",
            compilation_config=no_compile_config,
            gpu_memory_utilization=0.4,
        ),
    ):
        pass


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_enforce_eager(vllm_runner, monkeypatch):
    """enforce_eager=True is expected to see no graphs and no stock compiles."""
    # Keep the engine in this process so the counter observes its updates.
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    with (
        compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0),
        # The runner is constructed inside the counter context because
        # loading the model is what triggers compilation (if enabled).
        vllm_runner(
            "facebook/opt-125m",
            enforce_eager=True,
            gpu_memory_utilization=0.4,
        ),
    ):
        pass
206
207


208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
@pytest.mark.forked
def test_torch_compile_disable(vllm_runner, monkeypatch):
    """With TORCH_COMPILE_DISABLE=1, no graphs or stock compiles are expected."""
    # Applied in this order: in-process engine, torch.compile disabled,
    # compile cache disabled.
    env_overrides = {
        "VLLM_ENABLE_V1_MULTIPROCESSING": "0",
        "TORCH_COMPILE_DISABLE": "1",
        "VLLM_DISABLE_COMPILE_CACHE": "1",
    }
    for name, value in env_overrides.items():
        monkeypatch.setenv(name, value)

    with (
        compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0),
        # The runner is constructed inside the counter context so that model
        # loading is covered by the expectation.
        vllm_runner("facebook/opt-125m", gpu_memory_utilization=0.4),
    ):
        pass


224
225
226
def test_splitting_ops_dynamic():
    """Check how VllmConfig resolves `splitting_ops` and `cudagraph_mode`.

    Covers the default config, inductor graph partition, the attention+quant
    fusion pass, an invalid fusion/splitting combination, and graph partition
    combined with the fusion pass.
    """

    def vllm_compile_config(**compilation_kwargs):
        # Every non-default case below runs in VLLM_COMPILE mode.
        return VllmConfig(
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE, **compilation_kwargs
            )
        )

    def attn_fusion_kwargs():
        # Built fresh on every call so no PassConfig or list object is shared
        # between the configs constructed below.
        return {
            "pass_config": PassConfig(fuse_attn_quant=True, eliminate_noops=True),
            "custom_ops": ["+quant_fp8"],
            "cudagraph_mode": CUDAGraphMode.PIECEWISE,
        }

    # Default config: cudagraph mode resolves to FULL_AND_PIECEWISE and the
    # splitting ops contain the attention ops.
    default_cc = VllmConfig().compilation_config
    assert default_cc.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
    assert default_cc.splitting_ops_contain_attention()

    # use_inductor_graph_partition=True: the given splitting_ops are used
    # directly as the partition rules.
    partition_cc = vllm_compile_config(
        use_inductor_graph_partition=True,
        splitting_ops=["vllm::unified_attention_with_output"],
    ).compilation_config
    assert partition_cc.splitting_ops == ["vllm::unified_attention_with_output"]

    # Attention fusion pass enabled: splitting ops end up empty and the
    # cudagraph mode falls back to FULL.
    fusion_cc = vllm_compile_config(**attn_fusion_kwargs()).compilation_config
    assert fusion_cc.splitting_ops == []
    assert fusion_cc.cudagraph_mode == CUDAGraphMode.FULL

    # splitting_ops cannot contain attention ops while the attention fusion
    # pass is enabled.
    with pytest.raises(ValidationError):
        vllm_compile_config(
            **attn_fusion_kwargs(),
            # workaround for getting the full list of attention ops
            splitting_ops=CompilationConfig()._attention_ops,
        )

    # Inductor graph partition together with the attention fusion pass: the
    # default splitting ops still include attention, and the requested
    # cudagraph mode is left unchanged.
    combined_cc = vllm_compile_config(
        use_inductor_graph_partition=True, **attn_fusion_kwargs()
    ).compilation_config
    assert combined_cc.splitting_ops_contain_attention()
    assert combined_cc.cudagraph_mode == CUDAGraphMode.PIECEWISE
290
291


292
293
294
295
296
297
298
299
300
301
302
303
def test_moe_splitting_ops_deepep_ht_inductor_partition():
    """User-provided splitting_ops survive DeepEP HT with inductor partition.

    With all2all_backend="deepep_high_throughput", data_parallel_size=8 and
    use_inductor_graph_partition=True, the attention op and the two MoE
    forward ops passed in must come back in the same order.
    """
    requested_ops = (
        "vllm::unified_attention_with_output",
        "vllm::moe_forward",
        "vllm::moe_forward_shared",
    )

    parallel_config = ParallelConfig(
        all2all_backend="deepep_high_throughput",
        data_parallel_size=8,
    )
    compilation_config = CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        use_inductor_graph_partition=True,
        # A fresh list, so the comparison below is against an independent copy.
        splitting_ops=list(requested_ops),
    )
    config = VllmConfig(
        parallel_config=parallel_config,
        compilation_config=compilation_config,
    )

    assert config.compilation_config.splitting_ops == list(requested_ops)


318
def test_should_split():
    """`should_split` matches a node by op packet name or exact overload name.

    Exercised with an aten op and with the `silly::attention` custom op, the
    latter as both an OpOverloadPacket target and an OpOverload target.
    """
    import torch

    from vllm.compilation.partition_rules import should_split

    graph = torch.fx.Graph()

    def call_function_node(target, args=()):
        # Only the target and args vary between the cases below.
        return torch.fx.Node(
            graph=graph,
            name="dummy_node",
            op="call_function",
            target=target,
            args=args,
            kwargs={},
        )

    add_node = call_function_node(torch.ops.aten.add.default)

    # Matching by OpOverloadPacket name.
    assert should_split(add_node, ["aten::add"])
    # Matching by the exact OpOverload name.
    assert should_split(add_node, ["aten::add.default"])
    # A different overload of the same packet does not match.
    assert not should_split(add_node, ["aten::add.Tensor"])

    # A single tensor object stands in for all four arguments.
    q = k = v = out = torch.randn(1)
    attention_args = (q, k, v, out)

    # Custom op whose target is the OpOverloadPacket.
    packet_node = call_function_node(torch.ops.silly.attention, attention_args)
    assert should_split(packet_node, ["silly::attention"])

    # Custom op whose target is the OpOverload.
    overload_node = call_function_node(
        torch.ops.silly.attention.default, attention_args
    )
    assert should_split(overload_node, ["silly::attention"])
    assert should_split(overload_node, ["silly::attention.default"])
375
376
377
378
379
380
381
382
383
384
385


@pytest.mark.skipif(
    not current_platform.support_static_graph_mode(),
    reason="Skip if not cudagraph mode supported",
)
@pytest.mark.parametrize(
    (
        "cudagraph_capture_sizes",
        "max_cudagraph_capture_size",
        "tp_size",
        "enable_sp",
        "max_num_batched_tokens",
        "cudagraph_mode",
        "expected_max_size",
    ),
    [
        (None, None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
        ([1, 2, 4], 4, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
        (
            [1, 2, 4],
            8,
            1,
            False,
            2048,
            CUDAGraphMode.FULL_AND_PIECEWISE,
            ValidationError,
        ),
        ([1, 256], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
        ([], None, 1, False, 2048, CUDAGraphMode.NONE, 0),
        (None, 0, 1, False, 2048, CUDAGraphMode.NONE, 0),
        # truncated to nearest multiple of 8 or 16
        (None, 257, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
        # max from list
        ([1, 2, 4, 15], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 15),
        # filtered out 15 due to SP
        ([1, 2, 4, 15], None, 2, True, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
        # limited by the max_tokens
        ([1, 2, 4, 15], None, 1, False, 8, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
        # the list should contain at least 1 element when use cudagraph
        ([], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
        # the max capturing size should be >= 1 when use cudagraph
        (None, 0, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
    ],
)
def test_cudagraph_sizes_post_init(
    cudagraph_capture_sizes,
    max_cudagraph_capture_size,
    tp_size,
    enable_sp,
    max_num_batched_tokens,
    cudagraph_mode,
    expected_max_size,
):
    """Check the resolved `max_cudagraph_capture_size` of an engine config.

    `expected_max_size` is either the expected integer or `ValidationError`,
    in which case building the config must raise that error. The platform's
    `device_count` is patched to return `tp_size`.
    """
    expects_error = expected_max_size == ValidationError
    ctx = pytest.raises(expected_max_size) if expects_error else nullcontext()

    # Forward only the size options the case actually specifies.
    optional_sizes = {
        "cudagraph_capture_sizes": cudagraph_capture_sizes,
        "max_cudagraph_capture_size": max_cudagraph_capture_size,
    }
    kwargs = {
        name: value for name, value in optional_sizes.items() if value is not None
    }

    with (
        ctx,
        patch.object(current_platform, "device_count", return_value=tp_size),
    ):
        pass_config = PassConfig(
            enable_sp=enable_sp,
            fuse_norm_quant=True,
            fuse_act_quant=True,
            eliminate_noops=True,
            sp_min_token_num=512 if enable_sp else None,
        )
        compilation_config = CompilationConfig(
            pass_config=pass_config,
            cudagraph_mode=cudagraph_mode,
            **kwargs,
        )
        engine_args = EngineArgs(
            model="facebook/opt-125m",
            tensor_parallel_size=tp_size,
            max_num_seqs=min(max_num_batched_tokens, 128),
            max_num_batched_tokens=max_num_batched_tokens,
            compilation_config=compilation_config,
        )
        vllm_config = engine_args.create_engine_config()

        resolved_max = vllm_config.compilation_config.max_cudagraph_capture_size
        assert resolved_max == expected_max_size
466
467


468
def test_cached_compilation_config(default_vllm_config):
    """`set_current_vllm_config` must override a cached compilation config.

    The default compilation config is cached first; a config enabling the
    `+quant_fp8` custom op is then made current, and the code generated for
    a compiled QuantFP8 layer must call the custom op.
    """
    import torch
    from torch._inductor.utils import run_and_get_code

    from vllm.config import get_cached_compilation_config, set_current_vllm_config
    from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
    from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

    dtype = torch.bfloat16
    device = torch.device(f"{DEVICE_TYPE}:0")
    batch_size, num_qo_heads, head_size = 8, 16, 128
    hidden_size = num_qo_heads * head_size

    # Access, and thereby cache, the default compilation config. It does not
    # contain the +quant_fp8 custom op; if it were used, the generated code
    # would use an inductor-generated triton kernel instead of the custom op
    # `torch.ops._C.static_scaled_fp8_quant`.
    get_cached_compilation_config()

    fp8_custom_op_config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["+quant_fp8"],
        )
    )

    # Entering set_current_vllm_config should clear the cached compilation
    # config so the one from fp8_custom_op_config is used instead.
    with set_current_vllm_config(fp8_custom_op_config):
        compiled_quant = torch.compile(
            QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)
        )

        _q_scale = torch.tensor(1.0, dtype=torch.float32, device=DEVICE_TYPE)
        query = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)

        _, code_chunks = run_and_get_code(compiled_quant, query, _q_scale)

    generated_code = " ".join(code_chunks)
    assert "torch.ops._C.static_scaled_fp8_quant.default(" in generated_code
508
509


510
511
512
513
514
515
516
517
518
519
520
521
522
def _create_vllm_config_for_validation(
    compilation_config: CompilationConfig,
) -> MagicMock:
    """Build a VllmConfig-spec'd mock for padding validation tests.

    Args:
        compilation_config: The compilation config the mock should expose.

    Returns:
        A `MagicMock(spec=VllmConfig)` carrying the given compilation config,
        a scheduler config with `max_num_seqs=8`, a default ParallelConfig,
        and no speculative or LoRA config.
    """
    stub_fields = {
        "compilation_config": compilation_config,
        "scheduler_config": SchedulerConfig.default_factory(max_num_seqs=8),
        "parallel_config": ParallelConfig(),
        "speculative_config": None,
        "lora_config": None,
    }
    mock_config = MagicMock(spec=VllmConfig)
    for field_name, value in stub_fields.items():
        setattr(mock_config, field_name, value)
    return mock_config


523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
def test_compile_sizes_padding_validation():
    """compile_sizes that would be padded to a capture size must be rejected.

    With cudagraph_capture_sizes=[1, 2, 4, 8], sizes 1, 2, 4 and 8 map to
    themselves, 3 is padded to 4, and 5 through 7 are padded to 8. So
    compile_sizes=[3] or [5] must fail, while sizes equal to capture sizes
    pass. Validation is skipped when cudagraphs are disabled.
    """

    def config_with_capture_sizes(compile_sizes, mode):
        # Fresh capture-size list on every call; only compile_sizes and the
        # cudagraph mode differ between cases.
        return CompilationConfig(
            cudagraph_capture_sizes=[1, 2, 4, 8],
            max_cudagraph_capture_size=8,
            compile_sizes=compile_sizes,
            cudagraph_mode=mode,
        )

    def init_dispatcher_keys(config, mode):
        dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config))
        dispatcher.initialize_cudagraph_keys(mode)

    # Sizes that fall between capture sizes: 3 -> 4 and 5 -> 8.
    for padded_sizes in ([3], [5]):
        with pytest.raises(ValueError, match="would be padded to"):
            config = config_with_capture_sizes(padded_sizes, CUDAGraphMode.FULL)
            config.post_init_cudagraph_sizes()
            init_dispatcher_keys(config, CUDAGraphMode.FULL)

    # Sizes that coincide with the capture sizes are accepted.
    config = config_with_capture_sizes([1, 2, 4, 8], CUDAGraphMode.FULL)
    config.post_init_cudagraph_sizes()
    assert sorted(config.compile_sizes) == [1, 2, 4, 8]
    init_dispatcher_keys(config, CUDAGraphMode.FULL)  # Should not raise

    # The "cudagraph_capture_sizes" placeholder expands to the capture sizes.
    config = config_with_capture_sizes(
        ["cudagraph_capture_sizes"], CUDAGraphMode.FULL
    )
    config.post_init_cudagraph_sizes()
    assert sorted(config.compile_sizes) == [1, 2, 4, 8]

    # When cudagraphs are disabled (max_cudagraph_capture_size=0),
    # padding validation should be skipped.
    config = CompilationConfig(
        cudagraph_capture_sizes=[],
        max_cudagraph_capture_size=0,
        compile_sizes=[3, 5, 7],  # would be invalid with cudagraphs
    )
    config.post_init_cudagraph_sizes()
    assert sorted(config.compile_sizes) == [3, 5, 7]

    # When cudagraph_mode is NONE but capture sizes are non-empty,
    # padding validation should still be skipped.
    config = config_with_capture_sizes([3, 5, 7], CUDAGraphMode.NONE)
    config.post_init_cudagraph_sizes()
    assert sorted(config.compile_sizes) == [3, 5, 7]
    init_dispatcher_keys(config, CUDAGraphMode.NONE)  # Should not raise
598
599


600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
def test_inductor_asserts_default_disabled(monkeypatch):
    """At VLLM_LOGGING_LEVEL=INFO, inductor runtime asserts default to off.

    Only checked on torch older than 2.12.0.dev.
    """
    monkeypatch.setenv("VLLM_LOGGING_LEVEL", "INFO")

    import importlib

    import vllm.envs

    # Reload vllm.envs after the environment variable has been set.
    importlib.reload(vllm.envs)

    config = CompilationConfig()
    if _is_torch_equal_or_newer(torch.__version__, "2.12.0.dev"):
        return

    for assert_flag in ("size_asserts", "alignment_asserts", "scalar_asserts"):
        assert config.inductor_compile_config.get(assert_flag) is False


def test_inductor_asserts_enabled_in_debug(monkeypatch):
    """At VLLM_LOGGING_LEVEL=DEBUG, inductor runtime asserts are turned on.

    Only checked on torch older than 2.12.0.dev.
    """
    monkeypatch.setenv("VLLM_LOGGING_LEVEL", "DEBUG")

    import importlib

    import vllm.envs

    # Reload vllm.envs after the environment variable has been set.
    importlib.reload(vllm.envs)

    config = CompilationConfig()
    if _is_torch_equal_or_newer(torch.__version__, "2.12.0.dev"):
        return

    for assert_flag in ("size_asserts", "alignment_asserts", "scalar_asserts"):
        assert config.inductor_compile_config.get(assert_flag) is True


636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
def test_get_inductor_factors_includes_configs():
    """Flipping an inductor or functorch config flag must change the factors.

    Each flag is toggled via the config module's `patch` context manager and
    the factors gathered inside it are compared against the baseline.
    """
    from torch._functorch import config as functorch_config
    from torch._inductor import config as inductor_config

    from vllm.compilation.compiler_interface import get_inductor_factors

    baseline = get_inductor_factors()

    toggles = (
        (
            inductor_config,
            "max_autotune",
            "inductor config change was not reflected",
        ),
        (
            functorch_config,
            "donated_buffer",
            "functorch config change was not reflected",
        ),
    )
    for config_module, flag, failure_message in toggles:
        flipped = not getattr(config_module, flag)
        with config_module.patch(flag, flipped):
            patched = get_inductor_factors()
        assert baseline != patched, failure_message


654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
def test_inductor_asserts_user_override(monkeypatch):
    """An explicit `inductor_compile_config` entry beats the logging default.

    `size_asserts=True` is passed at INFO level and must be kept. On torch
    older than 2.12.0.dev, `alignment_asserts` must still be False.
    """
    monkeypatch.setenv("VLLM_LOGGING_LEVEL", "INFO")

    import importlib

    import vllm.envs

    # Reload vllm.envs after the environment variable has been set.
    importlib.reload(vllm.envs)

    user_overrides = {"size_asserts": True}
    config = CompilationConfig(inductor_compile_config=user_overrides)

    assert config.inductor_compile_config.get("size_asserts") is True
    if _is_torch_equal_or_newer(torch.__version__, "2.12.0.dev"):
        return
    assert config.inductor_compile_config.get("alignment_asserts") is False