# test_config.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

from vllm.compilation.counter import compilation_counter
6
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
7
8
from vllm.config.compilation import CompilationLevel
from vllm.utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
9
10
11


def test_version():
    """Check torch version comparison against a ``2.8.0.dev`` target.

    The private ``_is_torch_equal_or_newer`` takes the version string to test
    as an explicit argument, so the comparison logic can be exercised on
    fixed inputs.
    """
    # Dev-dated (+cu128) and a0+git pre-release strings of 2.8.0, the final
    # 2.8.0 release, and a later patch release all satisfy the target.
    assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev")
    assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev")
    assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev")
    assert _is_torch_equal_or_newer("2.8.1", "2.8.0.dev")
    # An older minor release must not satisfy the target.
    assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")


def test_use_cudagraphs_dynamic():
    """Check the ``use_cudagraph`` flag on a default ``VllmConfig``."""
    vllm_config = VllmConfig()
    # The ``use_cudagraph`` flag is expected to be truthy on a default config.
    # NOTE(review): this flag alone presumably does not mean graphs will be
    # captured; a default config is expected to resolve ``cudagraph_mode`` to
    # NONE (see test_splitting_ops_dynamic) — confirm how the two interact.
    assert vllm_config.compilation_config.use_cudagraph


def test_custom_op():
    """Validate the entry syntax accepted by ``CompilationConfig.custom_ops``."""
    # Entries prefixed with "+" or "-" are accepted.
    _ = CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])

    # A bare op name without a +/- prefix is rejected with a ValueError.
    with pytest.raises(ValueError, match="Invalid syntax '"):
        _ = CompilationConfig(custom_ops=["quant_fp8"])


# forked is needed to work around https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
# NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends
# on the state of the cache directory on the current machine, which
# may be influenced by other tests.
@pytest.mark.parametrize("val", ["1"])
def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
    """Loading a model with ``VLLM_DISABLE_COMPILE_CACHE`` set must not cache.

    Expects zero cache-entry updates and zero saved compiled artifacts while
    ``facebook/opt-125m`` is loaded.
    """
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val)

    compilation_config = {
        "use_cudagraph": False,  # speed things up a bit
    }
    with (
        compilation_counter.expect(
            num_cache_entries_updated=0, num_compiled_artifacts_saved=0
        ),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config=compilation_config,
            gpu_memory_utilization=0.4,
        ) as _,
    ):
        pass


# forked is needed to work around https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
@pytest.mark.parametrize("enabled", [True, False])
def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
    """Check compilation and cudagraph counters with ``use_cudagraph`` on/off.

    Exactly one graph is expected to be seen in both cases. The GPU runner
    capture is triggered once, producing 13 captured cudagraphs, only when
    ``enabled`` is True; otherwise both counters stay at zero.
    """
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    compilation_config = {
        "cudagraph_capture_sizes": [100],
        "use_cudagraph": enabled,
    }
    with (
        compilation_counter.expect(
            num_graphs_seen=1,
            num_gpu_runner_capture_triggers=1 if enabled else 0,
            # NOTE(review): 13 presumably reflects the number of captured
            # graph pieces for opt-125m at the single capture size — confirm.
            num_cudagraph_captured=13 if enabled else 0,
        ),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config=compilation_config,
            gpu_memory_utilization=0.4,
        ) as _,
    ):
        pass


# forked is needed to work around https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_dynamo_as_is(vllm_runner, monkeypatch):
    """With compilation ``level`` 1, ``dynamo_as_is_count`` must reach exactly 1."""
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    with (
        compilation_counter.expect(dynamo_as_is_count=1),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config={"level": 1},
            gpu_memory_utilization=0.4,
        ) as _,
    ):
        pass


# forked is needed to work around https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_no_compilation(vllm_runner, monkeypatch):
    """With compilation ``level`` 0, no graphs are seen and dynamo-as-is is unused."""
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    with (
        compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config={"level": 0},
            gpu_memory_utilization=0.4,
        ) as _,
    ):
        pass


# forked is needed to work around https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_enforce_eager(vllm_runner, monkeypatch):
    """With ``enforce_eager=True``, no graphs are seen and dynamo-as-is is unused."""
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    with (
        compilation_counter.expect(num_graphs_seen=0, dynamo_as_is_count=0),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m", enforce_eager=True, gpu_memory_utilization=0.4
        ) as _,
    ):
        pass


def test_splitting_ops_dynamic():
    """Check how ``VllmConfig`` resolves ``splitting_ops`` and ``cudagraph_mode``.

    Covers a default config, ``use_inductor_graph_partition``, the attention
    fusion pass, and the combination of the latter two. Cases that set
    ``use_inductor_graph_partition`` only run on torch >= 2.9.0.dev.
    """
    inductor_partition_ok = is_torch_equal_or_newer("2.9.0.dev")

    # Default config: cudagraph_mode resolves to NONE and splitting_ops contain
    # no attention op. NOTE(review): presumably splitting ops are only
    # populated once piecewise compilation is selected — confirm.
    config = VllmConfig()
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
    assert not config.compilation_config.splitting_ops_contain_attention()

    # use_inductor_graph_partition=True: the user-supplied splitting_ops are
    # kept verbatim and used directly as the partition rules.
    if inductor_partition_ok:
        config = VllmConfig(
            compilation_config=CompilationConfig(
                level=CompilationLevel.PIECEWISE,
                use_inductor_graph_partition=True,
                splitting_ops=["vllm::unified_attention"],
            )
        )
        assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]

    # Attention fusion enabled: splitting_ops default to include attention ops
    # and the requested PIECEWISE cudagraph mode is preserved.
    config = VllmConfig(
        compilation_config=CompilationConfig(
            level=CompilationLevel.PIECEWISE,
            pass_config={"enable_attn_fusion": True, "enable_noop": True},
            custom_ops=["+quant_fp8"],
            cudagraph_mode=CUDAGraphMode.PIECEWISE,
        )
    )
    assert config.compilation_config.splitting_ops_contain_attention()
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE

    # Inductor graph partition and attention fusion together: the default
    # splitting_ops still include attention ops, and attention fusion is
    # supported directly under use_inductor_graph_partition=True, so
    # cudagraph_mode stays PIECEWISE.
    if inductor_partition_ok:
        config = VllmConfig(
            compilation_config=CompilationConfig(
                level=CompilationLevel.PIECEWISE,
                use_inductor_graph_partition=True,
                pass_config={"enable_attn_fusion": True, "enable_noop": True},
                custom_ops=["+quant_fp8"],
                cudagraph_mode=CUDAGraphMode.PIECEWISE,
            )
        )
        assert config.compilation_config.splitting_ops_contain_attention()
        assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE


def test_resolve_operator_overload():
    """Check that ``resolve_defined_ops`` maps qualified names to op overloads.

    Unknown operator names are dropped from the result instead of raising,
    and the surviving overloads keep the order of the input names.
    """
    import torch

    from vllm.compilation.partition_rules import resolve_defined_ops

    mm_overload = torch.ops.aten.mm.default
    addmm_overload = torch.ops.aten.addmm.default

    name_lists = [
        # Every name is valid: one overload per input, in input order.
        ["aten::mm.default", "aten::addmm.default"],
        # An unknown op in the middle is skipped rather than raising.
        [
            "aten::mm.default",
            "aten::nonexistent_op.default",
            "aten::addmm.default",
        ],
    ]
    for names in name_lists:
        ops = resolve_defined_ops(names)
        assert len(ops) == 2
        assert ops[0] is mm_overload
        assert ops[1] is addmm_overload