test_config.py 17.7 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import copy
4
import logging
5
from contextlib import nullcontext
6
from unittest.mock import patch
7

8
import pytest
9
from pydantic import ValidationError
10
11

from vllm.compilation.counter import compilation_counter
12
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
13
from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmConfig
14
from vllm.config.compilation import CompilationMode, PassConfig
15
from vllm.engine.arg_utils import EngineArgs
16
from vllm.logger import _print_warning_once
17
from vllm.platforms import current_platform
18
from vllm.utils.torch_utils import _is_torch_equal_or_newer
19

20
21
22
# This import automatically registers `torch.ops.silly.attention`
from . import silly_attention  # noqa: F401

23
24

def test_version():
    """Check `_is_torch_equal_or_newer` against a "2.8.0.dev" requirement.

    Dev builds, pre-releases carrying a local version label, and releases at
    or above 2.8.0 must satisfy the requirement; an older release must not.
    """
    satisfying_versions = (
        "2.8.0.dev20250624+cu128",  # dev build with a local version label
        "2.8.0a0+gitc82a174",  # alpha pre-release with a git local label
        "2.8.0",
        "2.8.1",
    )
    for version in satisfying_versions:
        assert _is_torch_equal_or_newer(version, "2.8.0.dev"), version

    # A release older than the requirement must be rejected.
    assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")
31
32


33
34
35
36
37
38
39
40
41
42
43
44
45
46
def test_copy_pass():
    """A deep-copied inductor pass keeps its compilation settings."""
    vllm_config = VllmConfig()
    original_pass = FixFunctionalizationPass(vllm_config)
    cloned_pass = copy.deepcopy(original_pass)

    actual = cloned_pass.compilation_config
    expected = vllm_config.compilation_config
    for attr in ("use_inductor_graph_partition", "splitting_ops"):
        assert getattr(actual, attr) == getattr(expected, attr), attr


47
48
49
50
51
52
53
54
def test_custom_op():
    """`custom_ops` entries must be prefixed with '+' or '-'."""
    # Well-formed: every entry carries a prefix.
    CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])

    # A bare op name without a prefix is rejected.
    with pytest.raises(ValueError, match="Invalid syntax '"):
        CompilationConfig(custom_ops=["quant_fp8"])


55
56
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
# NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends
# on the state of the cache directory on the current machine, which
# may be influenced by other tests.
@pytest.mark.parametrize("val", ["1"])
def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
    """With the compile cache disabled, loading a model saves nothing to it.

    The counter must record zero cache-entry updates and zero saved compiled
    artifacts while the model is loaded.
    """
    env = {
        # Run the engine in-process so `compilation_counter` sees its updates.
        "VLLM_ENABLE_V1_MULTIPROCESSING": "0",
        "VLLM_DISABLE_COMPILE_CACHE": val,
    }
    for name, value in env.items():
        monkeypatch.setenv(name, value)

    # Skipping cudagraph capture speeds the test up a bit.
    compilation_config = {"cudagraph_mode": CUDAGraphMode.NONE}

    with compilation_counter.expect(
        num_cache_entries_updated=0, num_compiled_artifacts_saved=0
    ):
        # Loading the model causes compilation (if enabled) to happen.
        with vllm_runner(
            "facebook/opt-125m",
            compilation_config=compilation_config,
            gpu_memory_utilization=0.4,
        ):
            pass


83
84
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
@pytest.mark.parametrize(
    "cudagraph_mode,num_cudagraph_captured",
    [
        (CUDAGraphMode.NONE, 0),
        (CUDAGraphMode.FULL_DECODE_ONLY, 1),
        (CUDAGraphMode.PIECEWISE, 13),
        (CUDAGraphMode.FULL_AND_PIECEWISE, 14),
    ],
)
def test_use_cudagraphs(
    vllm_runner, monkeypatch, cudagraph_mode, num_cudagraph_captured
):
    """Count cudagraph captures per cudagraph mode with a single capture size.

    Exactly one graph is seen. The GPU runner triggers capture once unless the
    mode is NONE, and the number of captured graphs depends on the mode.
    """
    # Run the engine in-process so `compilation_counter` sees its updates.
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    compilation_config = dict(
        cudagraph_capture_sizes=[100],
        cudagraph_mode=cudagraph_mode,
    )
    capture_enabled = cudagraph_mode != CUDAGraphMode.NONE
    expectations = dict(
        num_graphs_seen=1,
        num_gpu_runner_capture_triggers=1 if capture_enabled else 0,
        num_cudagraph_captured=num_cudagraph_captured,
    )

    with compilation_counter.expect(**expectations):
        # Loading the model causes compilation (if enabled) to happen.
        with vllm_runner(
            "facebook/opt-125m",
            compilation_config=compilation_config,
            gpu_memory_utilization=0.4,
        ):
            pass
119
120
121
122


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_stock_torch_compile(vllm_runner, monkeypatch):
    """STOCK_TORCH_COMPILE mode goes through stock torch.compile exactly once."""
    # Run the engine in-process so `compilation_counter` sees its updates.
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    stock_mode = {"mode": CompilationMode.STOCK_TORCH_COMPILE}
    with compilation_counter.expect(stock_torch_compile_count=1):
        # Loading the model causes compilation (if enabled) to happen.
        with vllm_runner(
            "facebook/opt-125m",
            compilation_config=stock_mode,
            gpu_memory_utilization=0.4,
        ):
            pass


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_no_compilation(vllm_runner, monkeypatch):
    """CompilationMode.NONE sees no graphs and never calls stock torch.compile."""
    # Run the engine in-process so `compilation_counter` sees its updates.
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    no_compile_mode = {"mode": CompilationMode.NONE}
    with compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0):
        # Loading the model would cause compilation here, if it were enabled.
        with vllm_runner(
            "facebook/opt-125m",
            compilation_config=no_compile_mode,
            gpu_memory_utilization=0.4,
        ):
            pass


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_enforce_eager(vllm_runner, monkeypatch):
    """With enforce_eager=True no graphs are seen and nothing is compiled."""
    # Run the engine in-process so `compilation_counter` sees its updates.
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    with compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0):
        # Loading the model would cause compilation here, if it were enabled.
        with vllm_runner(
            "facebook/opt-125m",
            enforce_eager=True,
            gpu_memory_utilization=0.4,
        ):
            pass
170
171
172
173
174


def test_splitting_ops_dynamic():
    """Check how splitting_ops and cudagraph_mode resolve for several configs."""
    # Default config: resolves to FULL_AND_PIECEWISE cudagraphs, and the
    # resulting splitting ops include the attention ops.
    config = VllmConfig()
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
    assert config.compilation_config.splitting_ops_contain_attention()

    # When use_inductor_graph_partition=True
    config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            use_inductor_graph_partition=True,
            splitting_ops=["vllm::unified_attention"],
        )
    )
    # With inductor partition, splitting_ops are used directly as the
    # partition rules, so the user-provided list is kept as-is.
    assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]

    # When the attention fusion pass is enabled (without inductor partition),
    # splitting_ops is emptied...
    config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
            custom_ops=["+quant_fp8"],
            cudagraph_mode=CUDAGraphMode.PIECEWISE,
        )
    )
    assert config.compilation_config.splitting_ops == []
    # ...and the requested PIECEWISE cudagraph mode falls back to FULL.
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL

    # splitting_ops cannot contain attention ops when the attention fusion
    # pass is enabled; such a config is rejected.
    with pytest.raises(ValidationError):
        VllmConfig(
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
                custom_ops=["+quant_fp8"],
                cudagraph_mode=CUDAGraphMode.PIECEWISE,
                # Workaround for accessing the full list of attention ops.
                splitting_ops=CompilationConfig()._attention_ops,
            )
        )

    # When both use_inductor_graph_partition and attention fusion are enabled.
    config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            use_inductor_graph_partition=True,
            pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
            custom_ops=["+quant_fp8"],
            cudagraph_mode=CUDAGraphMode.PIECEWISE,
        )
    )
    # With inductor graph partition, attention fusion and splitting_ops
    # work together: the default splitting_ops include attention ops.
    assert config.compilation_config.splitting_ops_contain_attention()
    # fuse_attn_quant is directly supported under
    # use_inductor_graph_partition=True, so cudagraph_mode is unchanged.
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
236
237


238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
def test_moe_splitting_ops_deepep_ht_piecewise():
    """DeepEP high-throughput with dp>1 adds MoE ops to splitting_ops.

    This is the case without inductor partition and without attention fusion.
    MoE ops are meant to be added on top of the attention ops; only the MoE
    ops are asserted here.
    """
    parallel_config = ParallelConfig(
        all2all_backend="deepep_high_throughput",
        data_parallel_size=8,
    )
    compilation_config = CompilationConfig(mode=CompilationMode.VLLM_COMPILE)
    config = VllmConfig(
        parallel_config=parallel_config,
        compilation_config=compilation_config,
    )

    splitting_ops = config.compilation_config.splitting_ops
    assert splitting_ops is not None
    for moe_op in ("vllm::moe_forward", "vllm::moe_forward_shared"):
        assert moe_op in splitting_ops, moe_op


def test_moe_splitting_ops_deepep_ht_inductor_partition():
    """With inductor partition, user-provided splitting_ops are kept verbatim.

    The user list already contains the MoE ops that DeepEP high-throughput
    with dp>1 needs. The resolved list must therefore equal the input exactly:
    same order, and no MoE op appended a second time.
    """
    user_splitting_ops = [
        "vllm::unified_attention",
        "vllm::moe_forward",
        "vllm::moe_forward_shared",
    ]
    config = VllmConfig(
        parallel_config=ParallelConfig(
            all2all_backend="deepep_high_throughput",
            data_parallel_size=8,
        ),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            use_inductor_graph_partition=True,
            # Pass a copy so an in-place mutation by the config (e.g. an
            # append) cannot also change the expected list compared below.
            splitting_ops=list(user_splitting_ops),
        ),
    )
    assert config.compilation_config.splitting_ops == user_splitting_ops


def test_moe_splitting_ops_deepep_ht_attn_fusion_no_inductor():
    """Attention fusion without inductor partition keeps MoE ops out.

    Even with DeepEP high-throughput and dp>1, enabling the attention fusion
    pass without inductor graph partition must not re-enable piecewise
    compilation or add MoE ops to splitting_ops.
    """
    config = VllmConfig(
        parallel_config=ParallelConfig(
            all2all_backend="deepep_high_throughput",
            data_parallel_size=8,
        ),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            # Use the current field names; the old enable_attn_fusion /
            # enable_noop aliases are deprecated and emit warnings.
            pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
            custom_ops=["+quant_fp8"],
            cudagraph_mode=CUDAGraphMode.PIECEWISE,
        ),
    )
    assert config.compilation_config.splitting_ops == []
    # The requested PIECEWISE cudagraph mode falls back to FULL.
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL


302
def test_should_split():
    """`should_split` matches a node by op packet name or exact overload name."""
    import torch

    from vllm.compilation.partition_rules import should_split

    graph = torch.fx.Graph()

    def make_node(target, args=()):
        # Standalone call_function node for `target`; only the target varies.
        return torch.fx.Node(
            graph=graph,
            name="dummy_node",
            op="call_function",
            target=target,
            args=args,
            kwargs={},
        )

    # Built-in op: the node targets the OpOverload aten.add.default.
    node = make_node(torch.ops.aten.add.default)
    # The OpOverloadPacket name matches.
    assert should_split(node, ["aten::add"])
    # The exact OpOverload name matches.
    assert should_split(node, ["aten::add.default"])
    # A different overload of the same op does not match.
    assert not should_split(node, ["aten::add.Tensor"])

    q, k, v, out = [torch.randn(1)] * 4

    # Custom op targeted as an OpOverloadPacket.
    node = make_node(torch.ops.silly.attention, (q, k, v, out))
    assert should_split(node, ["silly::attention"])

    # Custom op targeted as an OpOverload: both the packet name and the
    # exact overload name match.
    node = make_node(torch.ops.silly.attention.default, (q, k, v, out))
    assert should_split(node, ["silly::attention"])
    assert should_split(node, ["silly::attention.default"])
359
360
361
362
363
364
365
366
367
368
369


@pytest.mark.skipif(
    not current_platform.support_static_graph_mode(),
    reason="Skip if not cudagraph mode supported",
)
@pytest.mark.parametrize(
    (
        "cudagraph_capture_sizes",
        "max_cudagraph_capture_size",
        "tp_size",
        "enable_sp",
        "max_num_batched_tokens",
        "cudagraph_mode",
        "expected_max_size",
    ),
    [
        (None, None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
        ([1, 2, 4], 4, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
        # explicit max size above the largest listed size is rejected
        (
            [1, 2, 4],
            8,
            1,
            False,
            2048,
            CUDAGraphMode.FULL_AND_PIECEWISE,
            ValidationError,
        ),
        ([1, 256], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
        # cudagraphs disabled: an empty list or a zero max size is accepted
        ([], None, 1, False, 2048, CUDAGraphMode.NONE, 0),
        (None, 0, 1, False, 2048, CUDAGraphMode.NONE, 0),
        # truncated to nearest multiple of 8 or 16
        (None, 257, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
        # max taken from the list
        ([1, 2, 4, 15], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 15),
        # 15 filtered out because sequence parallelism is enabled
        ([1, 2, 4, 15], None, 2, True, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
        # limited by max_num_batched_tokens
        ([1, 2, 4, 15], None, 1, False, 8, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
        # with cudagraphs enabled, the list must contain at least one size
        ([], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
        # with cudagraphs enabled, the max capture size must be >= 1
        (None, 0, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
    ],
)
def test_cudagraph_sizes_post_init(
    cudagraph_capture_sizes,
    max_cudagraph_capture_size,
    tp_size,
    enable_sp,
    max_num_batched_tokens,
    cudagraph_mode,
    expected_max_size,
):
    """Check the max cudagraph capture size resolved for each combination.

    `expected_max_size` is either the resolved size or `ValidationError` when
    the combination must be rejected.
    """
    expect_error = expected_max_size is ValidationError
    ctx = pytest.raises(ValidationError) if expect_error else nullcontext()

    # Pretend `tp_size` CUDA devices are visible (presumably so the requested
    # tensor parallel size validates on any host).
    fake_device_count = patch(
        "vllm.config.parallel.cuda_device_count_stateless", return_value=tp_size
    )

    with ctx, fake_device_count:
        pass_config = PassConfig(
            enable_sp=enable_sp,
            fuse_norm_quant=True,
            fuse_act_quant=True,
            eliminate_noops=True,
        )
        compilation_config = CompilationConfig(
            cudagraph_capture_sizes=cudagraph_capture_sizes,
            max_cudagraph_capture_size=max_cudagraph_capture_size,
            pass_config=pass_config,
            cudagraph_mode=cudagraph_mode,
        )
        vllm_config = EngineArgs(
            model="facebook/opt-125m",
            tensor_parallel_size=tp_size,
            max_num_seqs=min(max_num_batched_tokens, 128),
            max_num_batched_tokens=max_num_batched_tokens,
            compilation_config=compilation_config,
        ).create_engine_config()

        resolved = vllm_config.compilation_config.max_cudagraph_capture_size
        assert resolved == expected_max_size
445
446
447
448
449
450
451
452
453
454
455
456
457
458


def test_pass_config_deprecation(caplog_vllm):
    """Deprecated PassConfig flags warn, set their new fields, and hash alike."""
    caplog_vllm.set_level(logging.WARNING)

    # Clear the warn-once cache so each deprecation warning is re-issued.
    _print_warning_once.cache_clear()

    # Each deprecated flag paired with the new field(s) it must set.
    renamed_flags = [
        ("enable_fusion", ("fuse_norm_quant", "fuse_act_quant")),
        ("enable_attn_fusion", ("fuse_attn_quant",)),
        ("enable_noop", ("eliminate_noops",)),
        ("enable_sequence_parallelism", ("enable_sp",)),
        ("enable_async_tp", ("fuse_gemm_comms",)),
        ("enable_fi_allreduce_fusion", ("fuse_allreduce_rms",)),
    ]
    for old_name, new_names in renamed_flags:
        caplog_vllm.clear()
        config = PassConfig(**{old_name: True})
        assert f"{old_name} is deprecated" in caplog_vllm.text
        for new_name in new_names:
            assert getattr(config, new_name) is True
        # The deprecated attribute itself still reads back as set.
        assert getattr(config, old_name) is True

    # A config spelled with the old flags hashes like its new spelling.
    equivalent_spellings = [
        ({"enable_fusion": True}, {"fuse_norm_quant": True, "fuse_act_quant": True}),
        ({"enable_async_tp": True}, {"fuse_gemm_comms": True}),
    ]
    for old_kwargs, new_kwargs in equivalent_spellings:
        config_old = PassConfig(**old_kwargs)
        config_new = PassConfig(**new_kwargs)
        assert config_old.compute_hash() == config_new.compute_hash()