"vllm/vscode:/vscode.git/clone" did not exist on "1aa13615103c2ea47e36710a9b2e17dfe1909143"
test_config.py 7.11 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

from vllm.compilation.counter import compilation_counter
6
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
7
8
9
10
from vllm.utils import _is_torch_equal_or_newer


def test_version():
    """Exercise ``_is_torch_equal_or_newer`` on several torch version strings.

    Every string in ``at_least_target`` must compare >= "2.8.0.dev", while
    the older 2.7.1 release must not.
    """
    target = "2.8.0.dev"
    at_least_target = (
        "2.8.0.dev20250624+cu128",
        "2.8.0a0+gitc82a174",
        "2.8.0",
        "2.8.1",
    )
    for installed in at_least_target:
        assert _is_torch_equal_or_newer(installed, target)
    assert not _is_torch_equal_or_newer("2.7.1", target)


def test_use_cudagraphs_dynamic():
    """A default-constructed ``VllmConfig`` reports ``use_cudagraph`` as truthy."""
    default_compilation = VllmConfig().compilation_config
    assert default_compilation.use_cudagraph


def test_custom_op():
    """``custom_ops`` entries must carry a ``+``/``-`` prefix."""
    # Prefixed entries are accepted.
    CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])

    # An entry without a prefix is rejected with a ValueError.
    with pytest.raises(ValueError, match="Invalid syntax '"):
        CompilationConfig(custom_ops=["quant_fp8"])


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
# NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends
# on the state of the cache directory on the current machine, which
# may be influenced by other tests.
@pytest.mark.parametrize("val", ["1"])
def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
    """With the compile cache disabled, nothing is written to the cache.

    Expects zero cache-entry updates and zero saved compiled artifacts.
    """
    # Run the engine in this process so the global counter sees its events.
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val)

    # CUDA graphs are turned off only to speed the test up a bit.
    runner_kwargs = dict(
        compilation_config={"use_cudagraph": False},
        gpu_memory_utilization=0.4,
    )
    expected_counts = dict(
        num_cache_entries_updated=0,
        num_compiled_artifacts_saved=0,
    )
    with compilation_counter.expect(**expected_counts):
        # Loading the model triggers compilation (when enabled).
        with vllm_runner("facebook/opt-125m", **runner_kwargs):
            pass


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
@pytest.mark.parametrize("enabled", [True, False])
def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
    """CUDA graphs are captured only when ``use_cudagraph`` is enabled."""
    # Keep the engine in this process so the counter is shared with the test.
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    compilation_config = {
        "cudagraph_capture_sizes": [100],
        "use_cudagraph": enabled,
    }
    if enabled:
        # NOTE(review): 13 is presumably the number of piecewise subgraphs
        # captured for opt-125m at the single capture size 100 -- confirm.
        capture_triggers, graphs_captured = 1, 13
    else:
        capture_triggers, graphs_captured = 0, 0

    with compilation_counter.expect(
        num_graphs_seen=1,
        num_gpu_runner_capture_triggers=capture_triggers,
        num_cudagraph_captured=graphs_captured,
    ):
        # Loading the model triggers compilation (when enabled).
        with vllm_runner(
            "facebook/opt-125m",
            compilation_config=compilation_config,
            gpu_memory_utilization=0.4,
        ):
            pass


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_dynamo_as_is(vllm_runner, monkeypatch):
    """Compilation level 1 bumps ``dynamo_as_is_count`` exactly once."""
    # Keep the engine in this process so the counter is shared with the test.
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    with compilation_counter.expect(dynamo_as_is_count=1):
        # Loading the model triggers compilation (when enabled).
        with vllm_runner(
            "facebook/opt-125m",
            compilation_config={"level": 1},
            gpu_memory_utilization=0.4,
        ):
            pass


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_no_compilation(vllm_runner, monkeypatch):
    """Compilation level 0 sees no graphs and never takes the dynamo-as-is path."""
    # Keep the engine in this process so the counter is shared with the test.
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    untouched = dict(num_graphs_seen=0, dynamo_as_is_count=0)
    with compilation_counter.expect(**untouched):
        # Loading the model would trigger compilation if it were enabled.
        with vllm_runner(
            "facebook/opt-125m",
            compilation_config={"level": 0},
            gpu_memory_utilization=0.4,
        ):
            pass


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_enforce_eager(vllm_runner, monkeypatch):
    """``enforce_eager=True`` leaves the compilation counters untouched."""
    # Keep the engine in this process so the counter is shared with the test.
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    untouched = dict(num_graphs_seen=0, dynamo_as_is_count=0)
    with compilation_counter.expect(**untouched):
        # Loading the model would trigger compilation if it were enabled.
        with vllm_runner(
            "facebook/opt-125m", enforce_eager=True, gpu_memory_utilization=0.4
        ):
            pass


def test_splitting_ops_dynamic():
    """Check how ``splitting_ops`` and ``cudagraph_mode`` are resolved.

    Covers four cases:
    - the default config;
    - ``use_inductor_graph_partition=True`` (PyTorch 2.9+ only);
    - the attention-fusion pass;
    - graph partition and attention fusion combined.
    """
    import torch

    # ``_is_torch_equal_or_newer`` compares two explicit version strings
    # (installed version, target), so the installed torch version must be
    # passed in. A single-argument call does not check the installed torch.
    # Inductor graph partition is only available in PyTorch 2.9+. This is a
    # fast config check, so we branch on the version instead of pytest.skip.
    has_graph_partition = _is_torch_equal_or_newer(
        str(torch.__version__), "2.9.0.dev"
    )

    # Default config
    config = VllmConfig()
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
    assert config.compilation_config.splitting_ops_contain_attention()

    # When use_inductor_graph_partition=True
    if has_graph_partition:
        config = VllmConfig(
            compilation_config=CompilationConfig(
                use_inductor_graph_partition=True, splitting_ops=["silly_attention"]
            )
        )
        # should ignore splitting_ops
        assert config.compilation_config.splitting_ops == []

    # When the attn_fusion pass is enabled.
    config = VllmConfig(
        compilation_config=CompilationConfig(
            pass_config={"enable_attn_fusion": True, "enable_noop": True},
            custom_ops=["+quant_fp8"],
            cudagraph_mode=CUDAGraphMode.PIECEWISE,
        )
    )
    assert config.compilation_config.splitting_ops == []
    # cudagraph mode also falls back to FULL
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL

    # splitting_ops cannot contain attention ops when the attn_fusion
    # pass is enabled.
    with pytest.raises(AssertionError):
        VllmConfig(
            compilation_config=CompilationConfig(
                pass_config={"enable_attn_fusion": True, "enable_noop": True},
                custom_ops=["+quant_fp8"],
                cudagraph_mode=CUDAGraphMode.PIECEWISE,
                # workaround for accessing all attention ops
                splitting_ops=CompilationConfig()._attention_ops,
            )
        )

    # When both use_inductor_graph_partition and the attn_fusion pass
    # are enabled.
    if has_graph_partition:
        config = VllmConfig(
            compilation_config=CompilationConfig(
                use_inductor_graph_partition=True,
                pass_config={"enable_attn_fusion": True, "enable_noop": True},
                custom_ops=["+quant_fp8"],
                cudagraph_mode=CUDAGraphMode.PIECEWISE,
            )
        )
        assert config.compilation_config.splitting_ops == []
        # enable_attn_fusion is directly supported under
        # use_inductor_graph_partition=True, and cudagraph_mode
        # is unchanged.
        assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE