# "vllm/model_executor/models/bailing_moe.py" did not exist on "a99b9f7dee0ad261284cbcd823f5b37381d15ac1"
# test_config.py 11.5 KB
# Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import copy
4
from contextlib import nullcontext
5

6
7
8
import pytest

from vllm.compilation.counter import compilation_counter
9
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
10
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
11
from vllm.config.compilation import CompilationMode
12
13
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
14
from vllm.utils.torch_utils import _is_torch_equal_or_newer, is_torch_equal_or_newer
15
16
17


def test_version():
    """Check torch version comparison against a ``.dev`` lower bound.

    Uses the private ``_is_torch_equal_or_newer`` so arbitrary version strings
    can be compared without depending on the installed torch version.
    """
    # Nightly/dev and local-build versions of 2.8.0 satisfy a "2.8.0.dev" bound.
    assert _is_torch_equal_or_newer("2.8.0.dev20250624+cu128", "2.8.0.dev")
    assert _is_torch_equal_or_newer("2.8.0a0+gitc82a174", "2.8.0.dev")
    # Final releases at or above the bound also satisfy it.
    assert _is_torch_equal_or_newer("2.8.0", "2.8.0.dev")
    assert _is_torch_equal_or_newer("2.8.1", "2.8.0.dev")
    # An older minor release does not.
    assert not _is_torch_equal_or_newer("2.7.1", "2.8.0.dev")


def test_use_cudagraphs_dynamic():
    """Pin the default value of ``use_cudagraph`` on a bare ``VllmConfig()``."""
    vllm_config = VllmConfig()
    # A default config is expected to report ``use_cudagraph`` as truthy.
    # NOTE(review): test_splitting_ops_dynamic expects the same default config
    # to have ``cudagraph_mode == CUDAGraphMode.NONE``; confirm these two
    # settings are intended to disagree by default.
    assert vllm_config.compilation_config.use_cudagraph


def test_copy_pass():
    """A deep copy of an inductor pass keeps the source compilation settings."""
    config = VllmConfig()
    original = config.compilation_config
    duplicate = copy.deepcopy(FixFunctionalizationPass(config)).compilation_config

    # Each compared field must survive the deepcopy unchanged.
    for field_name in ("use_inductor_graph_partition", "splitting_ops"):
        assert getattr(duplicate, field_name) == getattr(original, field_name)


def test_custom_op():
    """Validate the ``+``/``-`` prefix syntax of ``custom_ops`` entries."""
    # Well-formed: each entry carries an explicit enable/disable marker.
    CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])

    # An entry without a leading marker is expected to be rejected.
    with pytest.raises(ValueError, match="Invalid syntax '"):
        CompilationConfig(custom_ops=["quant_fp8"])


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
# NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends
# on the state of the cache directory on the current machine, which
# may be influenced by other tests.
@pytest.mark.parametrize("val", ["1"])
def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
    """With the compile cache disabled, nothing is written to the cache.

    Expects zero cache-entry updates and zero saved compiled artifacts while
    the model is loaded.
    """
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", val)

    compilation_config = {
        "use_cudagraph": False,  # speed things up a bit
    }
    with (
        compilation_counter.expect(
            num_cache_entries_updated=0, num_compiled_artifacts_saved=0
        ),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config=compilation_config,
            gpu_memory_utilization=0.4,
        ) as _,
    ):
        pass


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
@pytest.mark.parametrize("enabled", [True, False])
def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
    """Cudagraph capture counters follow the ``use_cudagraph`` setting.

    Exactly one graph is seen either way; capture triggers and captured
    graphs are expected only when ``enabled`` is True.
    """
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    compilation_config = {
        "cudagraph_capture_sizes": [100],
        "use_cudagraph": enabled,
    }
    with (
        compilation_counter.expect(
            num_graphs_seen=1,
            num_gpu_runner_capture_triggers=1 if enabled else 0,
            # NOTE(review): 13 is specific to this model and capture-size
            # configuration; re-derive it if either changes.
            num_cudagraph_captured=13 if enabled else 0,
        ),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config=compilation_config,
            gpu_memory_utilization=0.4,
        ) as _,
    ):
        pass


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_stock_torch_compile(vllm_runner, monkeypatch):
    """STOCK_TORCH_COMPILE mode triggers exactly one stock torch.compile."""
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    with (
        compilation_counter.expect(stock_torch_compile_count=1),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config={"mode": CompilationMode.STOCK_TORCH_COMPILE},
            gpu_memory_utilization=0.4,
        ) as _,
    ):
        pass


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_no_compilation(vllm_runner, monkeypatch):
    """CompilationMode.NONE sees no graphs and runs no stock torch.compile."""
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    with (
        compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config={"mode": CompilationMode.NONE},
            gpu_memory_utilization=0.4,
        ) as _,
    ):
        pass


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_enforce_eager(vllm_runner, monkeypatch):
    """``enforce_eager=True`` sees no graphs and runs no stock torch.compile."""
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    with (
        compilation_counter.expect(num_graphs_seen=0, stock_torch_compile_count=0),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m", enforce_eager=True, gpu_memory_utilization=0.4
        ) as _,
    ):
        pass


def test_splitting_ops_dynamic():
    """Check how splitting_ops and cudagraph_mode resolve for several configs."""
    # Default config
    config = VllmConfig()
    # A default V1 config resolves cudagraph mode to CUDAGraphMode.NONE;
    # splitting ops are only populated when the engine decides to use
    # piecewise compilation.
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
    assert not config.compilation_config.splitting_ops_contain_attention()

    # When use_inductor_graph_partition=True
    if is_torch_equal_or_newer("2.9.0.dev"):
        config = VllmConfig(
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                use_inductor_graph_partition=True,
                splitting_ops=["vllm::unified_attention"],
            )
        )
        # with inductor partition we use splitting_ops directly for
        # partition rules
        assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]

    # When attn_fusion pass enabled, splitting_ops now default to attention ops.
    config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            pass_config={"enable_attn_fusion": True, "enable_noop": True},
            custom_ops=["+quant_fp8"],
            cudagraph_mode=CUDAGraphMode.PIECEWISE,
        )
    )
    # With the new simplified logic, attention fusion works with splitting_ops
    assert config.compilation_config.splitting_ops_contain_attention()
    # cudagraph mode remains PIECEWISE
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE

    # When both use_inductor_graph_partition and attn_fusion pass enabled.
    if is_torch_equal_or_newer("2.9.0.dev"):
        config = VllmConfig(
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                use_inductor_graph_partition=True,
                pass_config={"enable_attn_fusion": True, "enable_noop": True},
                custom_ops=["+quant_fp8"],
                cudagraph_mode=CUDAGraphMode.PIECEWISE,
            )
        )
        # With inductor graph partition, attn_fusion and splitting_ops
        # work together. Default splitting_ops include attention ops.
        assert config.compilation_config.splitting_ops_contain_attention()
        # enable_attn_fusion is directly supported under
        # use_inductor_graph_partition=True, and cudagraph_mode
        # is unchanged.
        assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE


def test_resolve_operator_overload():
    """``resolve_defined_ops`` maps op names to torch overloads, skipping unknowns."""
    import torch

    from vllm.compilation.partition_rules import resolve_defined_ops

    def _expect_mm_then_addmm(op_names):
        # Resolution must yield exactly two entries, mm then addmm, and they
        # must be the very overload objects exposed on torch.ops.
        ops = resolve_defined_ops(op_names)
        assert len(ops) == 2
        assert ops[0] is torch.ops.aten.mm.default
        assert ops[1] is torch.ops.aten.addmm.default

    # Every name is a real operator overload.
    _expect_mm_then_addmm(["aten::mm.default", "aten::addmm.default"])

    # An unknown operator in the middle is dropped rather than raising.
    _expect_mm_then_addmm(
        [
            "aten::mm.default",
            "aten::nonexistent_op.default",
            "aten::addmm.default",
        ]
    )


@pytest.mark.skipif(
    not current_platform.support_static_graph_mode(),
    reason="Skip if not cudagraph mode supported",
)
@pytest.mark.parametrize(
    (
        "cudagraph_capture_sizes",
        "max_cudagraph_capture_size",
        "tp_size",
        "enable_sequence_parallelism",
        "max_num_batched_tokens",
        "use_cudagraph",
        "expected_max_size",
    ),
    [
        (None, None, 1, False, 2048, True, 512),
        ([1, 2, 4], 4, 1, False, 2048, True, 4),
        ([1, 2, 4], 8, 1, False, 2048, True, RuntimeError),
        # every row must supply all seven parameters (use_cudagraph=True here)
        ([1, 256], None, 1, False, 2048, True, 256),
        ([], None, 1, False, 2048, False, 0),
        (None, 0, 1, False, 2048, False, 0),
        # truncated to nearest multiple of 8 or 16
        (None, 257, 1, False, 2048, True, 256),
        ([1, 2, 4, 15], None, 1, False, 2048, True, 15),  # max from list
        ([1, 2, 4, 15], None, 2, True, 2048, True, 4),  # filtered out 15 due to SP
        ([1, 2, 4, 15], None, 1, False, 8, True, 4),  # limited by the max_tokens
        # the list should contain at least 1 element when use cudagraph
        ([], None, 1, False, 2048, True, RuntimeError),
        # the max capturing size should be >= 1 when use cudagraph
        (None, 0, 1, False, 2048, True, RuntimeError),
    ],
)
def test_cudagraph_sizes_post_init(
    cudagraph_capture_sizes,
    max_cudagraph_capture_size,
    tp_size,
    enable_sequence_parallelism,
    max_num_batched_tokens,
    use_cudagraph,
    expected_max_size,
):
    """Check how ``max_cudagraph_capture_size`` is resolved by engine config.

    ``expected_max_size`` is either the integer the resolved compilation config
    must report, or an exception class that creating the engine config must
    raise.
    """
    # The parametrized error cases pass the exception *class*, so
    # ``isinstance(expected_max_size, Exception)`` would never match; check the
    # class hierarchy instead.
    expects_error = isinstance(expected_max_size, type) and issubclass(
        expected_max_size, Exception
    )
    ctx = pytest.raises(expected_max_size) if expects_error else nullcontext()

    cudagraph_mode = CUDAGraphMode.PIECEWISE if use_cudagraph else CUDAGraphMode.NONE
    with ctx:
        compilation_config = CompilationConfig(
            cudagraph_capture_sizes=cudagraph_capture_sizes,
            max_cudagraph_capture_size=max_cudagraph_capture_size,
            pass_config={
                "enable_sequence_parallelism": enable_sequence_parallelism,
                "enable_fusion": True,
                "enable_noop": True,
            },
            cudagraph_mode=cudagraph_mode,
        )
        engine_args = EngineArgs(
            model="facebook/opt-125m",
            tensor_parallel_size=tp_size,
            max_num_batched_tokens=max_num_batched_tokens,
            compilation_config=compilation_config,
        )
        vllm_config = engine_args.create_engine_config()

        # Kept inside the ``with``: in the error cases the exception leaves the
        # block before this line, so ``vllm_config`` is never read unbound.
        assert (
            vllm_config.compilation_config.max_cudagraph_capture_size
            == expected_max_size
        )