# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import copy

import pytest

from vllm.compilation.counter import compilation_counter
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
from vllm.config.compilation import CompilationMode
from vllm.utils.torch_utils import _is_torch_equal_or_newer, is_torch_equal_or_newer


def test_version():
    """Check the torch version comparison used to gate features.

    A ``.dev`` minimum is satisfied by nightly (``.devYYYYMMDD+cu128``),
    source (``a0+git...``), final and later patch releases of that version,
    and not satisfied by an older release.
    """
    # Use the private helper because it takes both version strings
    # explicitly. The public is_torch_equal_or_newer takes only the target,
    # presumably comparing against the installed torch. This keeps the cases
    # below independent of the environment.
    assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev")
    assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev")
    assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev")
    assert _is_torch_equal_or_newer("2.8.1", "2.8.0.dev")
    assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")


def test_use_cudagraphs_dynamic():
    """A default VllmConfig reports ``use_cudagraph`` as enabled."""
    vllm_config = VllmConfig()
    # Only the ``use_cudagraph`` flag is checked here, and it defaults to
    # enabled. Whether graphs are actually captured is decided separately:
    # test_splitting_ops_dynamic expects a default config to resolve
    # cudagraph_mode to NONE.
    assert vllm_config.compilation_config.use_cudagraph


def test_copy_pass():
    """Check that a deep-copied inductor pass keeps the compilation settings
    it was built with (the graph-partition flag and the splitting ops)."""
    vllm_config = VllmConfig()
    inductor_pass = FixFunctionalizationPass(vllm_config)
    copied_inductor_pass = copy.deepcopy(inductor_pass)
    assert (
        copied_inductor_pass.compilation_config.use_inductor_graph_partition
        == vllm_config.compilation_config.use_inductor_graph_partition
    )
    assert (
        copied_inductor_pass.compilation_config.splitting_ops
        == vllm_config.compilation_config.splitting_ops
    )


def test_custom_op():
    """Check custom_ops syntax validation.

    CompilationConfig accepts ``custom_ops`` entries prefixed with ``+`` or
    ``-`` (presumably enable/disable). It rejects an unprefixed op name with
    ``ValueError``.
    """
    # proper syntax
    _ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])

    # A bare op name without a +/- prefix is invalid.
    with pytest.raises(ValueError, match="Invalid syntax '"):
        _ = CompilationConfig(custom_ops=["quant_fp8"])


# forked needed to work around https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
# NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends
# on the state of the cache directory on the current machine, which
# may be influenced by other tests.
@pytest.mark.parametrize("val", ["1"])
def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
    """With VLLM_DISABLE_COMPILE_CACHE set, loading a model must neither
    update compile-cache entries nor save compiled artifacts."""
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val)

    compilation_config = {
        "use_cudagraph": False,  # speed things up a bit
    }
    with (
        compilation_counter.expect(
            num_cache_entries_updated=0, num_compiled_artifacts_saved=0
        ),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config=compilation_config,
            gpu_memory_utilization=0.4,
        ) as _,
    ):
        pass


# forked needed to work around https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
@pytest.mark.parametrize("enabled", [True, False])
def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
    """Check that ``use_cudagraph`` gates CUDA graph capture.

    When enabled, the GPU runner triggers capture once and a fixed number of
    graphs is captured. When disabled, nothing is captured. In both cases
    exactly one graph is seen by compilation.
    """
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    compilation_config = {
        "cudagraph_capture_sizes": [100],
        "use_cudagraph": enabled,
    }
    with (
        compilation_counter.expect(
            num_graphs_seen=1,
            num_gpu_runner_capture_triggers=1 if enabled else 0,
            # NOTE(review): 13 presumably counts one graph per piecewise
            # submodule of facebook/opt-125m at the single capture size
            # (100). Confirm and update if the model or splitting changes.
            num_cudagraph_captured=13 if enabled else 0,
        ),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config=compilation_config,
            gpu_memory_utilization=0.4,
        ) as _,
    ):
        pass


# forked needed to work around https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_stock_torch_compile(vllm_runner, monkeypatch):
    """Check that STOCK_TORCH_COMPILE mode increments
    ``stock_torch_compile_count`` exactly once while loading a model."""
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    with (
        compilation_counter.expect(stock_torch_compile_count=1),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config={"mode": CompilationMode.STOCK_TORCH_COMPILE},
            gpu_memory_utilization=0.4,
        ) as _,
    ):
        pass


# forked needed to work around https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_no_compilation(vllm_runner, monkeypatch):
    """Check that CompilationMode.NONE sees no graphs and never invokes
    stock torch.compile while loading a model."""
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    with (
        compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config={"mode": CompilationMode.NONE},
            gpu_memory_utilization=0.4,
        ) as _,
    ):
        pass


# forked needed to work around https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_enforce_eager(vllm_runner, monkeypatch):
    """Check that ``enforce_eager=True`` sees no graphs and never invokes
    stock torch.compile while loading a model."""
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    with (
        compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m", enforce_eager=True, gpu_memory_utilization=0.4
        ) as _,
    ):
        pass


def test_splitting_ops_dynamic():
    """Check how splitting_ops and cudagraph_mode resolve for several configs.

    Covers:
    - the default config;
    - inductor graph partition (torch >= 2.9.0.dev only);
    - the attention-fusion pass;
    - inductor graph partition combined with attention fusion.
    """
    # Default config
    config = VllmConfig()
    # A default config resolves cudagraph_mode to NONE, and its splitting_ops
    # contain no attention ops.
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
    assert not config.compilation_config.splitting_ops_contain_attention()

    # When use_inductor_graph_partition=True
    if is_torch_equal_or_newer("2.9.0.dev"):
        config = VllmConfig(
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                use_inductor_graph_partition=True,
                splitting_ops=["vllm::unified_attention"],
            )
        )
        # with inductor partition we use splitting_ops directly for
        # partition rules
        assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]

    # When the attn_fusion pass is enabled, splitting_ops default to
    # attention ops.
    config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            pass_config={"enable_attn_fusion": True, "enable_noop": True},
            custom_ops=["+quant_fp8"],
            cudagraph_mode=CUDAGraphMode.PIECEWISE,
        )
    )
    # Attention fusion coexists with attention in splitting_ops.
    assert config.compilation_config.splitting_ops_contain_attention()
    # cudagraph mode remains PIECEWISE
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE

    # When both use_inductor_graph_partition and attn_fusion pass enabled.
    if is_torch_equal_or_newer("2.9.0.dev"):
        config = VllmConfig(
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                use_inductor_graph_partition=True,
                pass_config={"enable_attn_fusion": True, "enable_noop": True},
                custom_ops=["+quant_fp8"],
                cudagraph_mode=CUDAGraphMode.PIECEWISE,
            )
        )
        # With inductor graph partition, attn_fusion and splitting_ops
        # work together. Default splitting_ops include attention ops.
        assert config.compilation_config.splitting_ops_contain_attention()
        # enable_attn_fusion is directly supported under
        # use_inductor_graph_partition=True, and cudagraph_mode
        # is unchanged.
        assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE


def test_resolve_operator_overload():
    """Check that resolve_defined_ops maps op names to torch ops.

    Each ``"namespace::op.overload"`` string maps to the matching
    ``torch.ops`` object, in input order. Names that do not resolve are
    dropped rather than raising.
    """
    import torch

    from vllm.compilation.partition_rules import resolve_defined_ops

    # Case 1: both names are real aten overloads. Each resolves to the very
    # same torch.ops object, in the order given.
    all_valid = resolve_defined_ops(["aten::mm.default", "aten::addmm.default"])
    assert len(all_valid) == 2
    assert all_valid[0] is torch.ops.aten.mm.default
    assert all_valid[1] is torch.ops.aten.addmm.default

    # Case 2: an unknown op between two valid ones is skipped instead of
    # raising. The surviving ops keep their relative order.
    names = [
        "aten::mm.default",
        "aten::nonexistent_op.default",  # cannot resolve -> dropped
        "aten::addmm.default",
    ]
    survivors = resolve_defined_ops(names)
    assert len(survivors) == 2  # only the two real ops remain
    expected = (torch.ops.aten.mm.default, torch.ops.aten.addmm.default)
    for got, want in zip(survivors, expected):
        assert got is want