"docs/features/tool_calling.md" did not exist on "5a5e29de88826a1d3d7aa4f1d621067401999e98"
test_config.py 15 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import copy
4
import logging
5
from contextlib import nullcontext
6
from unittest.mock import patch
7

8
import pytest
9
from pydantic import ValidationError
10
11

from vllm.compilation.counter import compilation_counter
12
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
13
from vllm.config import CompilationConfig, CUDAGraphMode, VllmConfig
14
from vllm.config.compilation import CompilationMode, PassConfig
15
from vllm.engine.arg_utils import EngineArgs
16
from vllm.logger import _print_warning_once
17
from vllm.platforms import current_platform
18
from vllm.utils.torch_utils import _is_torch_equal_or_newer
19

20
21
22
# This import automatically registers `torch.ops.silly.attention`
from . import silly_attention  # noqa: F401

23
24

def test_version():
    """Exercise the private torch version comparison helper directly."""
    baseline = "2.8.0.dev"

    # Version strings that must compare as equal to or newer than the baseline.
    not_older = (
        "2.8.0.dev20250624+cu128",
        "2.8.0a0+gitc82a174",
        "2.8.0",
        "2.8.1",
    )
    for candidate in not_older:
        assert _is_torch_equal_or_newer(candidate, baseline)

    # An older release must be rejected.
    assert not _is_torch_equal_or_newer("2.7.1", baseline)
31
32


33
34
35
36
37
38
39
40
41
42
43
44
45
46
def test_copy_pass():
    """A deep copy of FixFunctionalizationPass keeps the compilation settings."""
    vllm_config = VllmConfig()
    copied_pass = copy.deepcopy(FixFunctionalizationPass(vllm_config))

    copied_cfg = copied_pass.compilation_config
    source_cfg = vllm_config.compilation_config

    # The copy must agree with the config the pass was built from.
    assert (
        copied_cfg.use_inductor_graph_partition
        == source_cfg.use_inductor_graph_partition
    )
    assert copied_cfg.splitting_ops == source_cfg.splitting_ops


47
48
49
50
51
52
53
54
def test_custom_op():
    """Check which `custom_ops` entries CompilationConfig accepts."""
    # proper syntax: these entries must be accepted without raising
    CompilationConfig(custom_ops=["+quant_fp8", "-silu_and_mul"])

    # this entry must be rejected with an "Invalid syntax" error
    with pytest.raises(ValueError, match="Invalid syntax '"):
        CompilationConfig(custom_ops=["quant_fp8"])


55
56
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
# NB: VLLM_DISABLE_COMPILE_CACHE=0 is intentionally not covered, because the
# outcome would depend on the state of the cache directory on the current
# machine, which may be influenced by other tests.
@pytest.mark.parametrize("val", ["1"])
def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
    """Load a model with VLLM_DISABLE_COMPILE_CACHE=`val` and check the counters.

    The compilation counter is expected to report no updated cache entries
    and no saved compiled artifacts.
    """
    env_overrides = {
        # Disable multiprocessing so that the counter is in the same process
        "VLLM_ENABLE_V1_MULTIPROCESSING": "0",
        "VLLM_DISABLE_COMPILE_CACHE": val,
    }
    for env_name, env_value in env_overrides.items():
        monkeypatch.setenv(env_name, env_value)

    runner_kwargs = {
        # cudagraphs are turned off to speed things up a bit
        "compilation_config": {"cudagraph_mode": CUDAGraphMode.NONE},
        "gpu_memory_utilization": 0.4,
    }
    with (
        compilation_counter.expect(
            num_cache_entries_updated=0, num_compiled_artifacts_saved=0
        ),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner("facebook/opt-125m", **runner_kwargs),
    ):
        pass


83
84
# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
@pytest.mark.parametrize(
    "cudagraph_mode,num_cudagraph_captured",
    [
        (CUDAGraphMode.NONE, 0),
        (CUDAGraphMode.FULL_DECODE_ONLY, 1),
        (CUDAGraphMode.PIECEWISE, 13),
        (CUDAGraphMode.FULL_AND_PIECEWISE, 14),
    ],
)
def test_use_cudagraphs(
    vllm_runner, monkeypatch, cudagraph_mode, num_cudagraph_captured
):
    """Check the cudagraph-related counters for each parametrized mode."""
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    # One capture trigger is expected unless cudagraphs are disabled.
    if cudagraph_mode != CUDAGraphMode.NONE:
        expected_capture_triggers = 1
    else:
        expected_capture_triggers = 0

    expected_counts = {
        "num_graphs_seen": 1,
        "num_gpu_runner_capture_triggers": expected_capture_triggers,
        "num_cudagraph_captured": num_cudagraph_captured,
    }
    compilation_config = {
        "cudagraph_capture_sizes": [100],
        "cudagraph_mode": cudagraph_mode,
    }
    with (
        compilation_counter.expect(**expected_counts),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config=compilation_config,
            gpu_memory_utilization=0.4,
        ),
    ):
        pass
119
120
121
122


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_stock_torch_compile(vllm_runner, monkeypatch):
    """Expect `stock_torch_compile_count=1` in STOCK_TORCH_COMPILE mode."""
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    stock_mode_config = {"mode": CompilationMode.STOCK_TORCH_COMPILE}
    with (
        compilation_counter.expect(stock_torch_compile_count=1),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config=stock_mode_config,
            gpu_memory_utilization=0.4,
        ),
    ):
        pass


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_no_compilation(vllm_runner, monkeypatch):
    """Expect no graphs seen and no stock compiles with CompilationMode.NONE."""
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    expected_counts = {"num_graphs_seen": 0, "stock_torch_compile_count": 0}
    with (
        compilation_counter.expect(**expected_counts),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner(
            "facebook/opt-125m",
            compilation_config={"mode": CompilationMode.NONE},
            gpu_memory_utilization=0.4,
        ),
    ):
        pass


# forked needed to workaround https://github.com/vllm-project/vllm/issues/21073
@pytest.mark.forked
def test_enforce_eager(vllm_runner, monkeypatch):
    """Expect no graphs seen and no stock compiles when `enforce_eager=True`."""
    # Disable multiprocessing so that the counter is in the same process
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

    expected_counts = {"num_graphs_seen": 0, "stock_torch_compile_count": 0}
    runner_kwargs = {"enforce_eager": True, "gpu_memory_utilization": 0.4}
    with (
        compilation_counter.expect(**expected_counts),
        # loading the model causes compilation (if enabled) to happen
        vllm_runner("facebook/opt-125m", **runner_kwargs),
    ):
        pass
170
171
172
173
174


def test_splitting_ops_dynamic():
    """Check how `splitting_ops` and `cudagraph_mode` end up in VllmConfig.

    Four setups are covered: the default config, inductor graph partition
    with explicit splitting ops, the attention-quant fusion pass, and the
    combination of inductor graph partition with that fusion pass.
    """
    # Default config: the cudagraph mode is expected to be
    # FULL_AND_PIECEWISE and the splitting ops to contain the attention ops.
    config = VllmConfig()
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
    assert config.compilation_config.splitting_ops_contain_attention()

    # When use_inductor_graph_partition=True
    config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            use_inductor_graph_partition=True,
            splitting_ops=["vllm::unified_attention"],
        )
    )
    # with inductor partition we use splitting_ops directly for
    # partition rules
    assert config.compilation_config.splitting_ops == ["vllm::unified_attention"]

    # When the attn_fusion pass is enabled.
    config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
            custom_ops=["+quant_fp8"],
            cudagraph_mode=CUDAGraphMode.PIECEWISE,
        )
    )
    assert config.compilation_config.splitting_ops == []
    # the cudagraph mode also falls back to FULL
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL

    # splitting_ops cannot contain attention ops when the attn_fusion
    # pass is enabled. Only the raised error matters here, so the result
    # is not bound to a name.
    with pytest.raises(ValidationError):
        VllmConfig(
            compilation_config=CompilationConfig(
                mode=CompilationMode.VLLM_COMPILE,
                pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
                custom_ops=["+quant_fp8"],
                cudagraph_mode=CUDAGraphMode.PIECEWISE,
                # workaround for accessing all attention ops
                splitting_ops=CompilationConfig()._attention_ops,
            )
        )

    # When both use_inductor_graph_partition and the attn_fusion pass are
    # enabled.
    config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            use_inductor_graph_partition=True,
            pass_config=PassConfig(fuse_attn_quant=True, eliminate_noops=True),
            custom_ops=["+quant_fp8"],
            cudagraph_mode=CUDAGraphMode.PIECEWISE,
        )
    )
    # With inductor graph partition, attn_fusion and splitting_ops
    # work together. Default splitting_ops include attention ops.
    assert config.compilation_config.splitting_ops_contain_attention()
    # fuse_attn_quant is directly supported under
    # use_inductor_graph_partition=True, and cudagraph_mode
    # is unchanged.
    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
236
237


238
def test_should_split():
    """Check `should_split` matching for packet-level and overload-level op names."""
    import torch

    from vllm.compilation.partition_rules import should_split

    graph = torch.fx.Graph()

    def make_call_node(target, args=()):
        # Build a standalone call_function node for the given op target.
        return torch.fx.Node(
            graph=graph,
            name="dummy_node",
            op="call_function",
            target=target,
            args=args,
            kwargs={},
        )

    add_node = make_call_node(torch.ops.aten.add.default)

    # supports OpOverloadPacket
    assert should_split(add_node, ["aten::add"])

    # supports OpOverload
    assert should_split(add_node, ["aten::add.default"])

    # a different overload of the same op must not match
    assert not should_split(add_node, ["aten::add.Tensor"])

    q, k, v, out = [torch.randn(1)] * 4
    attention_args = (q, k, v, out)

    # supports custom ops as OpOverloadPacket
    packet_node = make_call_node(torch.ops.silly.attention, attention_args)
    assert should_split(packet_node, ["silly::attention"])

    # supports custom ops as OpOverload
    overload_node = make_call_node(torch.ops.silly.attention.default, attention_args)
    assert should_split(overload_node, ["silly::attention"])
    assert should_split(overload_node, ["silly::attention.default"])
295
296
297
298
299
300
301
302
303
304
305


@pytest.mark.skipif(
    not current_platform.support_static_graph_mode(),
    reason="Skip if not cudagraph mode supported",
)
@pytest.mark.parametrize(
    (
        "cudagraph_capture_sizes",
        "max_cudagraph_capture_size",
        "tp_size",
        "enable_sp",
        "max_num_batched_tokens",
        "cudagraph_mode",
        "expected_max_size",
    ),
    [
        (None, None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
        ([1, 2, 4], 4, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
        (
            [1, 2, 4],
            8,
            1,
            False,
            2048,
            CUDAGraphMode.FULL_AND_PIECEWISE,
            ValidationError,
        ),
        ([1, 256], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
        ([], None, 1, False, 2048, CUDAGraphMode.NONE, 0),
        (None, 0, 1, False, 2048, CUDAGraphMode.NONE, 0),
        # truncated to nearest multiple of 8 or 16
        (None, 257, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
        # max from list
        ([1, 2, 4, 15], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 15),
        # filtered out 15 due to SP
        ([1, 2, 4, 15], None, 2, True, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
        # limited by the max_tokens
        ([1, 2, 4, 15], None, 1, False, 8, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
        # the list should contain at least 1 element when use cudagraph
        ([], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
        # the max capturing size should be >= 1 when use cudagraph
        (None, 0, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, ValidationError),
    ],
)
def test_cudagraph_sizes_post_init(
    cudagraph_capture_sizes,
    max_cudagraph_capture_size,
    tp_size,
    enable_sp,
    max_num_batched_tokens,
    cudagraph_mode,
    expected_max_size,
):
    """Compare the resolved `max_cudagraph_capture_size` with the expectation.

    When `expected_max_size` is `ValidationError`, building the configuration
    is expected to raise that error instead of producing a size.
    """
    # Cases parametrized with the exception class must raise it; all other
    # cases run under a no-op context.
    outcome = (
        pytest.raises(expected_max_size)
        if expected_max_size == ValidationError
        else nullcontext()
    )
    # Make the stateless device count report `tp_size` devices.
    device_count_patch = patch(
        "vllm.config.parallel.cuda_device_count_stateless", return_value=tp_size
    )

    with outcome, device_count_patch:
        # Everything is built inside the context so that a ValidationError
        # from any construction step is observed by `pytest.raises`.
        compilation_config = CompilationConfig(
            cudagraph_capture_sizes=cudagraph_capture_sizes,
            max_cudagraph_capture_size=max_cudagraph_capture_size,
            pass_config=PassConfig(
                enable_sp=enable_sp,
                fuse_norm_quant=True,
                fuse_act_quant=True,
                eliminate_noops=True,
            ),
            cudagraph_mode=cudagraph_mode,
        )
        vllm_config = EngineArgs(
            model="facebook/opt-125m",
            tensor_parallel_size=tp_size,
            max_num_seqs=min(max_num_batched_tokens, 128),
            max_num_batched_tokens=max_num_batched_tokens,
            compilation_config=compilation_config,
        ).create_engine_config()

        resolved_max_size = vllm_config.compilation_config.max_cudagraph_capture_size
        assert resolved_max_size == expected_max_size
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430


def test_pass_config_deprecation(caplog_vllm):
    """Check the deprecated PassConfig flags against their replacements.

    For every deprecated flag set to True, a "<flag> is deprecated" warning
    must be logged, the replacement field(s) must be True, and the deprecated
    field itself must read back as None.
    """
    caplog_vllm.set_level(logging.WARNING)

    # Clear cache to ensure warnings are re-issued
    _print_warning_once.cache_clear()

    # (deprecated flag, replacement fields)
    renamed_flags = (
        ("enable_fusion", ("fuse_norm_quant", "fuse_act_quant")),
        ("enable_attn_fusion", ("fuse_attn_quant",)),
        ("enable_noop", ("eliminate_noops",)),
        ("enable_sequence_parallelism", ("enable_sp",)),
        ("enable_async_tp", ("fuse_gemm_comms",)),
        ("enable_fi_allreduce_fusion", ("fuse_allreduce_rms",)),
    )

    for deprecated_flag, replacement_fields in renamed_flags:
        # Start each case from an empty log so only its own warning is seen.
        caplog_vllm.clear()
        config = PassConfig(**{deprecated_flag: True})

        assert f"{deprecated_flag} is deprecated" in caplog_vllm.text
        for replacement in replacement_fields:
            assert getattr(config, replacement) is True
        assert getattr(config, deprecated_flag) is None