test_config.py 13.2 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import copy
4
from contextlib import nullcontext
5
from unittest.mock import patch
6

7
import pytest
8
from pydantic import ValidationError
9
10

from vllm.compilation.counter import compilation_counter
11
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
12
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
13
from vllm.config.compilation import CompilationMode
14
15
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
16
from vllm.utils.torch_utils import _is_torch_equal_or_newer
17
18
19


def test_version():
    """Check the torch version comparison helper against a ``2.8.0.dev`` target."""
    # Test the version comparison logic using the private function, which
    # takes the version string to compare as an explicit argument.
    target = "2.8.0.dev"

    # Nightly, source-build, release and patch-release version strings.
    assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", target)
    assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", target)
    assert _is_torch_equal_or_newer("2.8.0", target)
    assert _is_torch_equal_or_newer("2.8.1", target)

    # An older release must compare as not-newer.
    assert not _is_torch_equal_or_newer("2.7.1", target)
26
27


28
29
30
31
32
33
34
35
36
37
38
39
40
41
def test_copy_pass():
    """Deep-copying a FixFunctionalizationPass keeps its compilation config."""
    vllm_config = VllmConfig()
    inductor_pass = FixFunctionalizationPass(vllm_config)
    copied_inductor_pass = copy.deepcopy(inductor_pass)

    # Compare the copy's config against the config the pass was built from.
    original = vllm_config.compilation_config
    copied = copied_inductor_pass.compilation_config
    assert copied.use_inductor_graph_partition == original.use_inductor_graph_partition
    assert copied.splitting_ops == original.splitting_ops


42
43
44
45
46
47
48
49
def test_custom_op():
    """Validate the syntax check applied to ``custom_ops`` entries."""
    # proper syntax: every entry carries a leading "+" or "-"
    CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])

    # an entry without a prefix must be rejected
    with pytest.raises(ValueError, match="Invalid syntax '"):
        CompilationConfig(custom_ops=["quant_fp8"])


50
51
# forked is needed to work around https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
# NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends
# on the state of the cache directory on the current machine, which
# may be influenced by other tests.
@pytest.mark.parametrize("val", ["1"])
def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
    """With VLLM_DISABLE_COMPILE_CACHE set, expect no cache updates or saves."""
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val)

    compilation_config = {
        "cudagraph_mode": CUDAGraphMode.NONE,  # speed things up a bit
    }
    with (
        compilation_counter.expect(
            num_cache_entries_updated=0, num_compiled_artifacts_saved=0
        ),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config=compilation_config,
            gpu_memory_utilization=0.4,
        ) as _,
    ):
        pass


78
79
# forked is needed to work around https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
@pytest.mark.parametrize(
    "cudagraph_mode,num_cudagraph_captured",
    [
        (CUDAGraphMode.NONE, 0),
        (CUDAGraphMode.FULL_DECODE_ONLY, 1),
        (CUDAGraphMode.PIECEWISE, 13),
        (CUDAGraphMode.FULL_AND_PIECEWISE, 14),
    ],
)
def test_use_cudagraphs(
    vllm_runner, monkeypatch, cudagraph_mode, num_cudagraph_captured
):
    """Check the compilation counters recorded for each cudagraph mode.

    Each parametrized case pairs a ``CUDAGraphMode`` with the number of
    captured cudagraphs the counter is expected to report after the model
    has been loaded with a single capture size.
    """
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    compilation_config = {
        "cudagraph_capture_sizes": [100],
        "cudagraph_mode": cudagraph_mode,
    }
    # A capture is expected to be triggered exactly once unless cudagraphs
    # are turned off entirely.
    num_gpu_runner_capture_triggers = 1 if cudagraph_mode != CUDAGraphMode.NONE else 0
    with (
        compilation_counter.expect(
            num_graphs_seen=1,
            num_gpu_runner_capture_triggers=num_gpu_runner_capture_triggers,
            num_cudagraph_captured=num_cudagraph_captured,
        ),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config=compilation_config,
            gpu_memory_utilization=0.4,
        ) as _,
    ):
        pass
114
115
116
117


# forked is needed to work around https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_stock_torch_compile(vllm_runner, monkeypatch):
    """STOCK_TORCH_COMPILE mode must record exactly one stock torch.compile."""
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    with (
        compilation_counter.expect(stock_torch_compile_count=1),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config={"mode": CompilationMode.STOCK_TORCH_COMPILE},
            gpu_memory_utilization=0.4,
        ) as _,
    ):
        pass


# forked is needed to work around https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_no_compilation(vllm_runner, monkeypatch):
    """CompilationMode.NONE must record no graphs and no stock compiles."""
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    with (
        compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config={"mode": CompilationMode.NONE},
            gpu_memory_utilization=0.4,
        ) as _,
    ):
        pass


# forked is needed to work around https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_enforce_eager(vllm_runner, monkeypatch):
    """enforce_eager=True must record no graphs and no stock compiles."""
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    with (
        compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m", enforce_eager=True, gpu_memory_utilization=0.4
        ) as _,
    ):
        pass
165
166
167
168
169


def test_splitting_ops_dynamic():
    """Check how ``splitting_ops`` and ``cudagraph_mode`` are resolved.

    Covers the default config, inductor graph partition, the attention
    fusion pass, an invalid combination of attention fusion with attention
    splitting ops, and both features enabled together.
    """
    # Default config
    config = VllmConfig()
    # Default V1 config leaves cudagraph mode unset; splitting ops are only
    # populated when the engine decides to use piecewise compilation.
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
    assert not config.compilation_config.splitting_ops_contain_attention()

    # When use_inductor_graph_partition=True
    config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            use_inductor_graph_partition=True,
            splitting_ops=["vllm::unified_attention"],
        )
    )
    # with inductor partition we use splitting_ops directly for
    # partition rules
    assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]

    # When attn_fusion pass enabled.
    config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            pass_config={"enable_attn_fusion": True, "enable_noop": True},
            custom_ops=["+quant_fp8"],
            cudagraph_mode=CUDAGraphMode.PIECEWISE,
        )
    )
    assert config.compilation_config.splitting_ops == []
    # cudagraph mode also falls back to FULL
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL

    # splitting_ops cannot contain attention ops when the attn_fusion
    # pass is enabled.
    with pytest.raises(ValidationError):
        VllmConfig(
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                pass_config={"enable_attn_fusion": True, "enable_noop": True},
                custom_ops=["+quant_fp8"],
                cudagraph_mode=CUDAGraphMode.PIECEWISE,
                # workaround for accessing all attention ops
                splitting_ops=CompilationConfig()._attention_ops,
            )
        )

    # When both use_inductor_graph_partition and attn_fusion pass enabled.
    config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            use_inductor_graph_partition=True,
            pass_config={"enable_attn_fusion": True, "enable_noop": True},
            custom_ops=["+quant_fp8"],
            cudagraph_mode=CUDAGraphMode.PIECEWISE,
        )
    )
    # With inductor graph partition, attn_fusion and splitting_ops
    # work together. Default splitting_ops include attention ops.
    assert config.compilation_config.splitting_ops_contain_attention()
    # enable_attn_fusion is directly supported under
    # use_inductor_graph_partition=True, and cudagraph_mode
    # is unchanged.
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
231
232


233
def test_should_split():
    """Check ``should_split`` name matching for builtin and custom ops.

    A splitting-op name is expected to match a node whose target is either
    the named OpOverloadPacket or the exactly named OpOverload, and not to
    match a different overload of the same packet.
    """
    import torch

    from vllm.compilation.partition_rules import should_split

    graph = torch.fx.Graph()
    node = torch.fx.Node(
        graph=graph,
        name="dummy_node",
        op="call_function",
        target=torch.ops.aten.add.default,
        args=(),
        kwargs={},
    )

    # supports OpOverloadPacket
    splitting_ops = ["aten::add"]
    assert should_split(node, splitting_ops)

    # supports OpOverload
    splitting_ops = ["aten::add.default"]
    assert should_split(node, splitting_ops)

    # a different overload of the same packet must not match
    splitting_ops = ["aten::add.Tensor"]
    assert not should_split(node, splitting_ops)

    @torch.library.custom_op(
        "silly::attention",
        mutates_args=["out"],
    )
    def attention(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor
    ) -> None:
        out.copy_(q + k + v)

    # Build four distinct tensors; list multiplication would bind the same
    # tensor object to q, k, v and the mutated ``out`` argument.
    q, k, v, out = (torch.randn(1) for _ in range(4))

    # supports custom ops as OpOverloadPacket
    node = torch.fx.Node(
        graph=graph,
        name="dummy_node",
        op="call_function",
        target=torch.ops.silly.attention,
        args=(q, k, v, out),
        kwargs={},
    )

    splitting_ops = ["silly::attention"]
    assert should_split(node, splitting_ops)

    # supports custom ops as OpOverload
    node = torch.fx.Node(
        graph=graph,
        name="dummy_node",
        op="call_function",
        target=torch.ops.silly.attention.default,
        args=(q, k, v, out),
        kwargs={},
    )

    splitting_ops = ["silly::attention"]
    assert should_split(node, splitting_ops)

    splitting_ops = ["silly::attention.default"]
    assert should_split(node, splitting_ops)
299
300
301
302
303
304
305
306
307
308
309
310
311


@pytest.mark.skipif(
    not current_platform.support_static_graph_mode(),
    reason="Skip if not cudagraph mode supported",
)
@pytest.mark.parametrize(
    (
        "cudagraph_capture_sizes",
        "max_cudagraph_capture_size",
        "tp_size",
        "enable_sequence_parallelism",
        "max_num_batched_tokens",
        "cudagraph_mode",
        "expected_max_size",
    ),
    [
        (None, None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
        ([1, 2, 4], 4, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
        (
            [1, 2, 4],
            8,
            1,
            False,
            2048,
            CUDAGraphMode.FULL_AND_PIECEWISE,
            ValidationError,
        ),
        ([1, 256], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
        ([], None, 1, False, 2048, CUDAGraphMode.NONE, 0),
        (None, 0, 1, False, 2048, CUDAGraphMode.NONE, 0),
        # truncated to nearest multiple of 8 or 16
        (None, 257, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
        # max from list
        ([1, 2, 4, 15], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 15),
        # filtered out 15 due to SP
        ([1, 2, 4, 15], None, 2, True, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
        # limited by the max_tokens
        ([1, 2, 4, 15], None, 1, False, 8, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
        # the list should contain at least 1 element when use cudagraph
        ([], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
        # the max capturing size should be >= 1 when use cudagraph
        (None, 0, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
    ],
)
def test_cudagraph_sizes_post_init(
    cudagraph_capture_sizes,
    max_cudagraph_capture_size,
    tp_size,
    enable_sequence_parallelism,
    max_num_batched_tokens,
    cudagraph_mode,
    expected_max_size,
):
    """Check the resolved ``max_cudagraph_capture_size`` of the engine config.

    ``expected_max_size`` is either the integer the resolved config must
    report, or the ``ValidationError`` class when building the config is
    expected to fail.
    """
    # The sentinel is the exception class itself, so compare by identity.
    expects_error = expected_max_size is ValidationError
    ctx = pytest.raises(ValidationError) if expects_error else nullcontext()

    with (
        ctx,
        # Report ``tp_size`` devices so the requested tensor parallel size
        # can be configured regardless of the hardware running the test.
        patch("vllm.config.parallel.cuda_device_count_stateless", return_value=tp_size),
    ):
        compilation_config = CompilationConfig(
            cudagraph_capture_sizes=cudagraph_capture_sizes,
            max_cudagraph_capture_size=max_cudagraph_capture_size,
            pass_config={
                "enable_sequence_parallelism": enable_sequence_parallelism,
                "enable_fusion": True,
                "enable_noop": True,
            },
            cudagraph_mode=cudagraph_mode,
        )
        engine_args = EngineArgs(
            model="facebook/opt-125m",
            tensor_parallel_size=tp_size,
            max_num_seqs=min(max_num_batched_tokens, 128),
            max_num_batched_tokens=max_num_batched_tokens,
            compilation_config=compilation_config,
        )
        vllm_config = engine_args.create_engine_config()

        # Only reached when no ValidationError was raised above.
        assert (
            vllm_config.compilation_config.max_cudagraph_capture_size
            == expected_max_size
        )