# test_arg_utils.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import json
5
from argparse import ArgumentError
6
7
from contextlib import nullcontext
from dataclasses import dataclass, field
8
from typing import Annotated, Literal
9

10
11
import pytest

12
from vllm.config import CompilationConfig, config
13
14
15
16
17
18
19
20
21
22
23
24
from vllm.engine.arg_utils import (
    EngineArgs,
    contains_type,
    get_kwargs,
    get_type,
    get_type_hints,
    is_not_builtin,
    is_type,
    literal_to_kwargs,
    optional_type,
    parse_type,
)
25
from vllm.utils.argparse_utils import FlexibleArgumentParser
26
27


@pytest.mark.parametrize(
    ("converter", "value", "expected"),
    [
        (int, "42", 42),
        (float, "3.14", 3.14),
        (str, "Hello World!", "Hello World!"),
        (json.loads, '{"foo":1,"bar":2}', {"foo": 1, "bar": 2}),
    ],
)
def test_parse_type(converter, value, expected):
    """``parse_type(converter)`` returns a callable that converts a CLI string.

    The parameter is named ``converter`` (not ``type``) to avoid shadowing the
    builtin.
    """
    parse_type_func = parse_type(converter)
    assert parse_type_func(value) == expected


def test_optional_type():
    """``optional_type(int)`` maps the string "None" to ``None`` and otherwise
    parses the value with ``int``."""
    to_optional_int = optional_type(int)
    assert to_optional_int("42") == 42
    assert to_optional_int("None") is None


@pytest.mark.parametrize(
    ("type_hint", "target_type", "expected"),
    [
        (int, int, True),
        (int, float, False),
        (list[int], list, True),
        (list[int], tuple, False),
        (Literal[0, 1], Literal, True),
    ],
)
def test_is_type(type_hint, target_type, expected):
    """``is_type`` matches a hint against a type, including generic origins
    (``list[int]`` vs ``list``) and ``Literal``."""
    assert is_type(type_hint, target_type) == expected


@pytest.mark.parametrize(
    ("type_hints", "target_type", "expected"),
    [
        ({float, int}, int, True),
        ({int, tuple}, int, True),
        ({int, tuple[int]}, int, True),
        ({int, tuple[int, ...]}, int, True),
        ({int, tuple[int]}, float, False),
        ({int, tuple[int, ...]}, float, False),
        ({str, Literal["x", "y"]}, Literal, True),
    ],
)
def test_contains_type(type_hints, target_type, expected):
    """``contains_type`` reports whether any hint in the set is ``target_type``."""
    assert contains_type(type_hints, target_type) == expected


@pytest.mark.parametrize(
    ("type_hints", "target_type", "expected"),
    [
        ({int, float}, int, int),
        ({int, float}, str, None),
        ({str, Literal["x", "y"]}, Literal, Literal["x", "y"]),
    ],
)
def test_get_type(type_hints, target_type, expected):
    """``get_type`` returns the hint matching ``target_type``, or ``None``."""
    assert get_type(type_hints, target_type) == expected


@pytest.mark.parametrize(
    ("type_hints", "expected"),
    [
        ({Literal[1, 2]}, {"type": int, "choices": [1, 2]}),
        ({str, Literal["x", "y"]}, {"type": str, "metavar": ["x", "y"]}),
        ({Literal[1, "a"]}, Exception),
    ],
)
def test_literal_to_kwargs(type_hints, expected):
    """``literal_to_kwargs`` turns ``Literal`` hints into argparse kwargs.

    A ``Literal`` mixing value types (``Literal[1, "a"]``) must raise; every
    other case must produce exactly the expected kwargs.
    """
    context = pytest.raises(expected) if expected is Exception else nullcontext()
    with context:
        assert literal_to_kwargs(type_hints) == expected


# Minimal config used as the type of ``DummyConfig.nested_config`` to check
# that ``get_kwargs`` can build a parser for dataclass-typed fields.
@config
@dataclass
class NestedConfig:
    field: int = 1
    """field"""


# Dummy config with one field per kind of type hint that ``get_kwargs`` must
# convert into argparse keyword arguments (see ``test_get_kwargs``).
# The per-field docstrings are kept as-is: ``@config`` consumes them.
@config
@dataclass
class DummyConfig:
    regular_bool: bool = True
    """Regular bool with default True"""
    optional_bool: bool | None = None
    """Optional bool with default None"""
    optional_literal: Literal["x", "y"] | None = None
    """Optional literal with default None"""
    tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3))
    """Tuple with variable length"""
    tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2))
    """Tuple with fixed length"""
    list_n: list[int] = field(default_factory=lambda: [1, 2, 3])
    """List with variable length"""
    list_literal: list[Literal[1, 2]] = field(default_factory=list)
    """List with literal choices"""
    list_union: list[str | type[object]] = field(default_factory=list)
    """List with union type"""
    set_n: set[int] = field(default_factory=lambda: {1, 2, 3})
    """Set with variable length"""
    literal_literal: Literal[Literal[1], Literal[2]] = 1
    """Literal of literals with default 1"""
    json_tip: dict = field(default_factory=dict)
    """Dict which will be JSON in CLI"""
    nested_config: NestedConfig = field(default_factory=NestedConfig)
    """Nested config"""


@pytest.mark.parametrize(
    ("type_hint", "expected"),
    [
        (int, False),
        (DummyConfig, True),
    ],
)
def test_is_not_builtin(type_hint, expected):
    """``is_not_builtin`` is False for builtins and True for project classes."""
    assert is_not_builtin(type_hint) == expected


@pytest.mark.parametrize(
    ("type_hint", "expected"),
    [
        (Annotated[int, "annotation"], {int}),
        (int | None, {int, type(None)}),
        (Annotated[int | None, "annotation"], {int, type(None)}),
        (Annotated[int, "annotation"] | None, {int, type(None)}),
    ],
    ids=["Annotated", "or_None", "Annotated_or_None", "or_None_Annotated"],
)
def test_get_type_hints(type_hint, expected):
    """``get_type_hints`` strips ``Annotated`` and flattens unions into a set,
    regardless of which one wraps the other."""
    assert get_type_hints(type_hint) == expected


def test_get_kwargs():
    """``get_kwargs(DummyConfig)`` yields the expected argparse kwargs for each
    kind of field declared on ``DummyConfig``."""
    kwargs = get_kwargs(DummyConfig)
    # Printed so the full kwargs mapping shows up in pytest's captured output
    # when an assertion below fails.
    print(kwargs)

    # bools should not have their type set
    assert kwargs["regular_bool"].get("type") is None
    assert kwargs["optional_bool"].get("type") is None
    # optional literals should have None as a choice
    assert kwargs["optional_literal"]["choices"] == ["x", "y", "None"]
    # tuples should have the correct nargs
    assert kwargs["tuple_n"]["nargs"] == "+"
    assert kwargs["tuple_2"]["nargs"] == 2
    # lists should work
    assert kwargs["list_n"]["type"] is int
    assert kwargs["list_n"]["nargs"] == "+"
    # lists with literals should have the correct choices
    assert kwargs["list_literal"]["type"] is int
    assert kwargs["list_literal"]["nargs"] == "+"
    assert kwargs["list_literal"]["choices"] == [1, 2]
    # lists with unions should become str type.
    # If not, we cannot know which type to use for parsing
    assert kwargs["list_union"]["type"] is str
    # sets should work like lists
    assert kwargs["set_n"]["type"] is int
    assert kwargs["set_n"]["nargs"] == "+"
    # literals of literals should have merged choices
    assert kwargs["literal_literal"]["choices"] == [1, 2]
    # dict should have json tip in help
    json_tip = "Should either be a valid JSON string or JSON keys"
    assert json_tip in kwargs["json_tip"]["help"]
    # nested config should construct the nested config
    assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2)


@pytest.mark.parametrize(
    ("arg", "expected"),
    [
        (None, {}),
        ('{"video": {"num_frames": 123} }', {"video": {"num_frames": 123}}),
        (
            '{"video": {"num_frames": 123, "fps": 1.0, "foo": "bar"}, '
            '"image": {"foo": "bar"} }',
            {
                "video": {"num_frames": 123, "fps": 1.0, "foo": "bar"},
                "image": {"foo": "bar"},
            },
        ),
    ],
)
def test_media_io_kwargs_parser(arg, expected):
    """``--media-io-kwargs`` parses a JSON string into a (nested) dict and
    defaults to an empty dict when omitted."""
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    cli_args = [] if arg is None else ["--media-io-kwargs", arg]
    args = parser.parse_args(cli_args)

    assert args.media_io_kwargs == expected


@pytest.mark.parametrize(
    ("args", "expected"),
    [
        (["-O", "1"], "1"),
        (["-O", "2"], "2"),
        (["-O", "3"], "3"),
        (["-O0"], "0"),
        (["-O1"], "1"),
        (["-O2"], "2"),
        (["-O3"], "3"),
    ],
)
def test_optimization_level(args, expected):
    """
    Test that optimization levels, given either space-separated (-O 1, -O 2,
    -O 3) or joined (-O0 ... -O3), map to optimization_level and leave
    compilation_config.mode unset.
    """
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    parsed_args = parser.parse_args(args)
    assert parsed_args.optimization_level == expected
    assert parsed_args.compilation_config.mode is None


@pytest.mark.parametrize(
    ("args", "expected"),
    [
        (["-cc.mode=0"], 0),
        (["-cc.mode=1"], 1),
        (["-cc.mode=2"], 2),
        (["-cc.mode=3"], 3),
    ],
)
def test_mode_parser(args, expected):
    """
    Test compilation config modes (-cc.mode=int) map to compilation_config.
    """
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    parsed_args = parser.parse_args(args)
    assert parsed_args.compilation_config.mode == expected


def test_compilation_config():
    """``--compilation-config`` / ``-cc`` accept a JSON dict and populate the
    matching ``CompilationConfig`` fields; omitted, it is the default config."""
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())

    # default value
    args = parser.parse_args([])
    assert args.compilation_config == CompilationConfig()

    # set via the short `-cc <json>` form
    args = parser.parse_args(
        [
            "-cc",
            '{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], "backend": "eager"}',
        ]
    )
    # Separate asserts so a failure reports which field was wrong.
    assert args.compilation_config.mode == 3
    assert args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
    assert args.compilation_config.backend == "eager"

    # set via the long `--compilation-config=<json>` form
    args = parser.parse_args(
        [
            "--compilation-config="
            '{"mode": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
            '"backend": "inductor"}',
        ]
    )
    assert args.compilation_config.mode == 3
    assert args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
    assert args.compilation_config.backend == "inductor"


def test_prefix_cache_default():
    """``enable_prefix_caching`` defaults to None and honors both the
    ``--enable-prefix-caching`` and ``--no-enable-prefix-caching`` flags."""
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    args = parser.parse_args([])

    # should be None by default (depends on model).
    engine_args = EngineArgs.from_cli_args(args=args)
    assert engine_args.enable_prefix_caching is None

    # with flag to turn it on.
    args = parser.parse_args(["--enable-prefix-caching"])
    engine_args = EngineArgs.from_cli_args(args=args)
    assert engine_args.enable_prefix_caching is True

    # with disable flag to turn it off.
    # `is False` rather than `not ...`: the default (None) is also falsy, so a
    # plain `not` check would pass even if the disable flag were ignored.
    args = parser.parse_args(["--no-enable-prefix-caching"])
    engine_args = EngineArgs.from_cli_args(args=args)
    assert engine_args.enable_prefix_caching is False


@pytest.mark.parametrize(
    ("arg", "expected", "option"),
    [
        (None, None, "mm-processor-kwargs"),
        ("{}", {}, "mm-processor-kwargs"),
        ('{"num_crops": 4}', {"num_crops": 4}, "mm-processor-kwargs"),
        ('{"foo": {"bar": "baz"}}', {"foo": {"bar": "baz"}}, "mm-processor-kwargs"),
    ],
)
def test_composite_arg_parser(arg, expected, option):
    """JSON-valued options parse into (nested) dicts; when omitted, the
    parametrized cases expect None."""
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    cli_args = [] if arg is None else [f"--{option}", arg]
    args = parser.parse_args(cli_args)
    # argparse stores "--foo-bar" under the attribute name "foo_bar".
    assert getattr(args, option.replace("-", "_")) == expected


def test_human_readable_model_len():
    """``--max-model-len`` accepts plain ints and human-readable sizes.

    Lowercase suffixes (k/m/g/t) are decimal multipliers (powers of 1000) and
    accept fractional values, which are truncated to an int. Uppercase
    suffixes (K/M/G/T) are binary multipliers (powers of 1024) and reject
    fractional values.
    """
    # `exit_on_error` disabled to test invalid values below
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser(exit_on_error=False))

    def max_model_len(value: str) -> int | None:
        """Parse ``--max-model-len <value>`` and return the parsed result."""
        return parser.parse_args(["--max-model-len", value]).max_model_len

    args = parser.parse_args([])
    assert args.max_model_len is None

    assert max_model_len("1024") == 1024

    # Lowercase suffixes: decimal multipliers (powers of 1000)
    assert max_model_len("1m") == 1_000_000
    assert max_model_len("10k") == 10_000
    assert max_model_len("2g") == 2_000_000_000
    assert max_model_len("2t") == 2_000_000_000_000

    # Uppercase suffixes: binary multipliers (powers of 1024)
    assert max_model_len("3K") == 2**10 * 3
    assert max_model_len("10M") == 2**20 * 10
    assert max_model_len("4G") == 2**30 * 4
    assert max_model_len("4T") == 2**40 * 4

    # Decimal values with decimal multipliers
    assert max_model_len("10.2k") == 10200
    # ..truncated to the nearest int
    assert max_model_len("10.2123451234567k") == 10212
    assert max_model_len("10.2123451234567m") == 10212345
    assert max_model_len("10.2123451234567g") == 10212345123
    assert max_model_len("10.2123451234567t") == 10212345123456

    # Invalid: unknown suffix, non-numeric input, a decimal without a suffix,
    # and decimals combined with binary multipliers
    for invalid in ["1a", "pwd", "10.24", "1.23M", "1.22T"]:
        with pytest.raises(ArgumentError):
            parser.parse_args(["--max-model-len", invalid])