# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
4
from contextlib import nullcontext
5
from unittest.mock import patch
6

7
import pytest
8
from pydantic import ValidationError
9
10

from vllm.compilation.counter import compilation_counter
11
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
12
from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmConfig
13
from vllm.config.compilation import CompilationMode, PassConfig
14
15
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
16
from vllm.utils.torch_utils import _is_torch_equal_or_newer
17

18
19
20
# This import automatically registers `torch.ops.silly.attention`
from . import silly_attention  # noqa: F401

21
22

def test_version():
    """Check the ordering rules of ``_is_torch_equal_or_newer``.

    Nightly, local-build and release spellings of 2.8.0 (and the newer 2.8.1)
    must compare as at least ``"2.8.0.dev"``; 2.7.1 must not.
    """
    # Test the version comparison logic using the private function
    assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev")
    assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev")
    assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev")
    assert _is_torch_equal_or_newer("2.8.1", "2.8.0.dev")
    assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")
29
30


31
32
33
34
35
36
37
38
39
40
41
42
43
44
def test_copy_pass():
    """A deep copy of FixFunctionalizationPass keeps its compilation config.

    Compares ``use_inductor_graph_partition`` and ``splitting_ops`` on the
    copied pass against the values on the originating ``VllmConfig``.
    """
    vllm_config = VllmConfig()
    original_pass = FixFunctionalizationPass(vllm_config)
    cloned_pass = copy.deepcopy(original_pass)

    expected = vllm_config.compilation_config
    actual = cloned_pass.compilation_config
    assert actual.use_inductor_graph_partition == expected.use_inductor_graph_partition
    assert actual.splitting_ops == expected.splitting_ops


45
46
47
48
49
50
51
52
def test_custom_op():
    """``custom_ops`` entries written as ``+name``/``-name`` are accepted;
    a bare op name is rejected with a ``ValueError``.
    """
    # well-formed entries construct without error
    CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])

    # an entry without a leading sign is reported as invalid syntax
    with pytest.raises(ValueError, match="Invalid syntax '"):
        CompilationConfig(custom_ops=["quant_fp8"])


53
54
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
# NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends
# on the state of the cache directory on the current machine, which
# may be influenced by other tests.
@pytest.mark.parametrize("val", ["1"])
def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
    """With the compile cache disabled, loading a model must neither update
    cache entries nor save compiled artifacts.
    """
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val)

    compilation_config = {
        "cudagraph_mode": CUDAGraphMode.NONE,  # speed things up a bit
    }
    with (
        compilation_counter.expect(
            num_cache_entries_updated=0, num_compiled_artifacts_saved=0
        ),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config=compilation_config,
            gpu_memory_utilization=0.4,
        ) as _,
    ):
        # Nothing to run: the counter expectation is checked around model load.
        pass


81
82
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
@pytest.mark.parametrize(
    "cudagraph_mode,num_cudagraph_captured",
    [
        (CUDAGraphMode.NONE, 0),
        (CUDAGraphMode.FULL_DECODE_ONLY, 1),
        (CUDAGraphMode.PIECEWISE, 13),
        (CUDAGraphMode.FULL_AND_PIECEWISE, 14),
    ],
)
def test_use_cudagraphs(
    vllm_runner, monkeypatch, cudagraph_mode, num_cudagraph_captured
):
    """Check how many CUDA graphs are captured for each ``cudagraph_mode``.

    A single capture size (100) is configured, so the expected counts in the
    parametrization are per-mode totals for that one size.
    """
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    compilation_config = {
        "cudagraph_capture_sizes": [100],
        "cudagraph_mode": cudagraph_mode,
    }
    # Capture is triggered once for every mode except NONE.
    num_gpu_runner_capture_triggers = 1 if cudagraph_mode != CUDAGraphMode.NONE else 0
    with (
        compilation_counter.expect(
            num_graphs_seen=1,
            num_gpu_runner_capture_triggers=num_gpu_runner_capture_triggers,
            num_cudagraph_captured=num_cudagraph_captured,
        ),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config=compilation_config,
            gpu_memory_utilization=0.4,
        ) as _,
    ):
        pass
117
118
119
120


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_stock_torch_compile(vllm_runner, monkeypatch):
    """``CompilationMode.STOCK_TORCH_COMPILE`` must register exactly one
    stock ``torch.compile`` invocation while the model loads.
    """
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    with (
        compilation_counter.expect(stock_torch_compile_count=1),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config={"mode": CompilationMode.STOCK_TORCH_COMPILE},
            gpu_memory_utilization=0.4,
        ) as _,
    ):
        pass


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_no_compilation(vllm_runner, monkeypatch):
    """``CompilationMode.NONE`` must see no graphs and perform no stock
    ``torch.compile`` while the model loads.
    """
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    with (
        compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config={"mode": CompilationMode.NONE},
            gpu_memory_utilization=0.4,
        ) as _,
    ):
        pass


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_enforce_eager(vllm_runner, monkeypatch):
    """``enforce_eager=True`` must see no graphs and perform no stock
    ``torch.compile`` while the model loads.
    """
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    with (
        compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m", enforce_eager=True, gpu_memory_utilization=0.4
        ) as _,
    ):
        pass
168
169
170
171
172


def test_splitting_ops_dynamic():
    """Check how ``VllmConfig`` resolves ``splitting_ops`` and ``cudagraph_mode``.

    Scenarios, in order: the default config, inductor graph partition with
    user-provided splitting ops, the attention+quant fusion pass, the fusion
    pass combined with attention splitting ops (rejected), and the fusion
    pass together with inductor graph partition.
    """
    # Default config
    config = VllmConfig()
    # The default config resolves cudagraph mode to FULL_AND_PIECEWISE and
    # its splitting ops contain the attention ops.
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
    assert config.compilation_config.splitting_ops_contain_attention()

    # When use_inductor_graph_partition=True
    config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            use_inductor_graph_partition=True,
            splitting_ops=["vllm::unified_attention"],
        )
    )
    # with inductor partition we use splitting_ops directly for
    # partition rules
    assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]

    # When attn_fusion pass enabled.
    config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
            custom_ops=["+quant_fp8"],
            cudagraph_mode=CUDAGraphMode.PIECEWISE,
        )
    )
    assert config.compilation_config.splitting_ops == []
    # cudagraph mode also falls back to FULL
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL

    # splitting_ops can not contain attention ops when attn_fusion
    # pass enabled.
    with pytest.raises(ValidationError):
        VllmConfig(
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
                custom_ops=["+quant_fp8"],
                cudagraph_mode=CUDAGraphMode.PIECEWISE,
                # workaround for accessing all attention ops
                splitting_ops=CompilationConfig()._attention_ops,
            )
        )

    # When both use_inductor_graph_partition and attn_fusion pass enabled.
    config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            use_inductor_graph_partition=True,
            pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
            custom_ops=["+quant_fp8"],
            cudagraph_mode=CUDAGraphMode.PIECEWISE,
        )
    )
    # With inductor graph partition, attn_fusion and splitting_ops
    # work together. Default splitting_ops include attention ops.
    assert config.compilation_config.splitting_ops_contain_attention()
    # fuse_attn_quant is directly supported under
    # use_inductor_graph_partition=True, and cudagraph_mode
    # is unchanged.
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
234
235


236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
def test_moe_splitting_ops_deepep_ht_piecewise():
    """DeepEP high-throughput with dp>1, no inductor partition, no attention
    fusion: the resolved ``splitting_ops`` must include both MoE forward ops.
    """
    parallel = ParallelConfig(
        all2all_backend="deepep_high_throughput",
        data_parallel_size=8,
    )
    compilation = CompilationConfig(mode=CompilationMode.VLLM_COMPILE)
    config = VllmConfig(parallel_config=parallel, compilation_config=compilation)

    resolved_ops = config.compilation_config.splitting_ops
    assert resolved_ops is not None
    # Both MoE entry points are expected as split points.
    assert "vllm::moe_forward" in resolved_ops
    assert "vllm::moe_forward_shared" in resolved_ops


def test_moe_splitting_ops_deepep_ht_inductor_partition():
    """Inductor partition with DeepEP high-throughput and dp>1: the
    user-provided ``splitting_ops`` must come back unchanged.

    NOTE(review): the provided list already contains both MoE ops, so this
    only demonstrates preservation (no reordering or duplication); it does
    not show MoE ops being appended to a list that lacks them.
    """
    user_splitting_ops = [
        "vllm::unified_attention",
        "vllm::moe_forward",
        "vllm::moe_forward_shared",
    ]
    parallel = ParallelConfig(
        all2all_backend="deepep_high_throughput",
        data_parallel_size=8,
    )
    compilation = CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        use_inductor_graph_partition=True,
        # pass a copy so the expected list stays independent of the config
        splitting_ops=list(user_splitting_ops),
    )
    config = VllmConfig(parallel_config=parallel, compilation_config=compilation)

    assert config.compilation_config.splitting_ops == user_splitting_ops


def test_moe_splitting_ops_deepep_ht_attn_fusion_no_inductor():
    """Attention fusion without inductor partition, with DeepEP
    high-throughput and dp>1: ``splitting_ops`` must resolve to an empty list
    and ``cudagraph_mode`` must resolve to FULL.
    """
    # Pure attn-fusion case without inductor partition: even with
    # DeepEP HT and dp>1, we should not re-enable piecewise compilation
    # or add MoE ops into splitting_ops.
    config = VllmConfig(
        parallel_config=ParallelConfig(
            all2all_backend="deepep_high_throughput",
            data_parallel_size=8,
        ),
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            pass_config={"fuse_attn_quant": True, "eliminate_noops": True},
            custom_ops=["+quant_fp8"],
            cudagraph_mode=CUDAGraphMode.PIECEWISE,
        ),
    )
    assert config.compilation_config.splitting_ops == []
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL


300
def test_should_split():
    """Check ``should_split`` name matching for aten and custom ops.

    Nodes are built directly with ``torch.fx.Node`` targeting either an
    ``OpOverload`` (e.g. ``aten.add.default``) or an ``OpOverloadPacket``
    (e.g. ``silly.attention``), and matched against splitting-op names given
    with and without the overload suffix.
    """
    import torch

    from vllm.compilation.partition_rules import should_split

    graph = torch.fx.Graph()
    node = torch.fx.Node(
        graph=graph,
        name="dummy_node",
        op="call_function",
        target=torch.ops.aten.add.default,
        args=(),
        kwargs={},
    )

    # a packet name (no overload suffix) matches the node's overload
    splitting_ops = ["aten::add"]
    assert should_split(node, splitting_ops)

    # the exact overload name matches
    splitting_ops = ["aten::add.default"]
    assert should_split(node, splitting_ops)

    # a different overload of the same op does not match
    splitting_ops = ["aten::add.Tensor"]
    assert not should_split(node, splitting_ops)

    q, k, v, out = [torch.randn(1)] * 4

    # supports custom ops as OpOverloadPacket
    node = torch.fx.Node(
        graph=graph,
        name="dummy_node",
        op="call_function",
        target=torch.ops.silly.attention,
        args=(q, k, v, out),
        kwargs={},
    )

    splitting_ops = ["silly::attention"]
    assert should_split(node, splitting_ops)

    # supports custom ops as OpOverload
    node = torch.fx.Node(
        graph=graph,
        name="dummy_node",
        op="call_function",
        target=torch.ops.silly.attention.default,
        args=(q, k, v, out),
        kwargs={},
    )

    splitting_ops = ["silly::attention"]
    assert should_split(node, splitting_ops)

    splitting_ops = ["silly::attention.default"]
    assert should_split(node, splitting_ops)
357
358
359
360
361
362
363
364
365
366
367


@pytest.mark.skipif(
    not current_platform.support_static_graph_mode(),
    reason="Skip if not cudagraph mode supported",
)
@pytest.mark.parametrize(
    (
        "cudagraph_capture_sizes",
        "max_cudagraph_capture_size",
        "tp_size",
        "enable_sp",
        "max_num_batched_tokens",
        "cudagraph_mode",
        "expected_max_size",
    ),
    [
        (None, None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
        ([1, 2, 4], 4, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
        (
            [1, 2, 4],
            8,
            1,
            False,
            2048,
            CUDAGraphMode.FULL_AND_PIECEWISE,
            ValidationError,
        ),
        ([1, 256], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
        ([], None, 1, False, 2048, CUDAGraphMode.NONE, 0),
        (None, 0, 1, False, 2048, CUDAGraphMode.NONE, 0),
        # truncated to nearest multiple of 8 or 16
        (None, 257, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
        # max from list
        ([1, 2, 4, 15], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 15),
        # filtered out 15 due to SP
        ([1, 2, 4, 15], None, 2, True, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
        # limited by the max_tokens
        ([1, 2, 4, 15], None, 1, False, 8, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
        # the list should contain at least 1 element when use cudagraph
        ([], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
        # the max capturing size should be >= 1 when use cudagraph
        (None, 0, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
    ],
)
def test_cudagraph_sizes_post_init(
    cudagraph_capture_sizes,
    max_cudagraph_capture_size,
    tp_size,
    enable_sp,
    max_num_batched_tokens,
    cudagraph_mode,
    expected_max_size,
):
    """Check the resolved ``max_cudagraph_capture_size`` of an engine config.

    ``expected_max_size`` is either the expected integer value or the
    ``ValidationError`` class, in which case building the configuration is
    expected to raise it.
    """
    # Identity check: the sentinel is the exception class itself, every other
    # expected value is an int.
    if expected_max_size is ValidationError:
        ctx = pytest.raises(expected_max_size)
    else:
        ctx = nullcontext()

    with (
        ctx,
        # Report tp_size devices regardless of the host's actual device count.
        patch("vllm.config.parallel.cuda_device_count_stateless", return_value=tp_size),
    ):
        compilation_config = CompilationConfig(
            cudagraph_capture_sizes=cudagraph_capture_sizes,
            max_cudagraph_capture_size=max_cudagraph_capture_size,
            pass_config=PassConfig(
                enable_sp=enable_sp,
                fuse_norm_quant=True,
                fuse_act_quant=True,
                eliminate_noops=True,
            ),
            cudagraph_mode=cudagraph_mode,
        )
        engine_args = EngineArgs(
            model="facebook/opt-125m",
            tensor_parallel_size=tp_size,
            max_num_seqs=min(max_num_batched_tokens, 128),
            max_num_batched_tokens=max_num_batched_tokens,
            compilation_config=compilation_config,
        )
        vllm_config = engine_args.create_engine_config()

        # Not reached in the ValidationError cases: the exception raised above
        # exits the block first.
        assert (
            vllm_config.compilation_config.max_cudagraph_capture_size
            == expected_max_size
        )