# test_argparse_utils.py (13.6 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
# ruff: noqa
4

5
import json
6
import os
7

8
import pytest
9
import yaml
10
from transformers import AutoTokenizer
11
from pydantic import ValidationError
12

13
from vllm.tokenizers.detokenizer_utils import convert_ids_list_to_tokens
14

# The following lines were last committed by Cyrus Leung.
15
16
from vllm.utils.argparse_utils import FlexibleArgumentParser
from ..utils import flat_product
17

18

# Tests for FlexibleArgumentParser
@pytest.fixture
def parser():
    """Build a ``FlexibleArgumentParser`` with a small set of test options.

    Registers a choice-restricted option, plain string and int options, a
    boolean flag, two JSON-typed options (``--compilation-config`` also has
    the short alias ``-cc``) and an int ``--optimization-level``.

    Returns:
        A fresh parser per test (function-scoped fixture), so tests may add
        extra arguments without affecting each other.
    """
    parser = FlexibleArgumentParser()
    parser.add_argument(
        "--image-input-type", choices=["pixel_values", "image_features"]
    )
    parser.add_argument("--model-name")
    parser.add_argument("--batch-size", type=int)
    parser.add_argument("--enable-feature", action="store_true")
    # JSON-typed options are the targets of the dotted ``--opt.key value``
    # syntax exercised in test_dict_args.
    parser.add_argument("--hf-overrides", type=json.loads)
    parser.add_argument("-cc", "--compilation-config", type=json.loads)
    parser.add_argument("--optimization-level", type=int)
    return parser


@pytest.fixture
def parser_with_config():
    """Build a parser shaped like ``vllm serve`` that also accepts ``--config``.

    ``serve`` is a required positional and ``model_tag`` an optional one
    (``nargs="?"``), so tests can cover a positional model, ``--model`` and a
    model supplied by the config file.
    """
    parser = FlexibleArgumentParser()
    parser.add_argument("serve")
    parser.add_argument("model_tag", nargs="?")
    parser.add_argument("--model", type=str)
    parser.add_argument("--served-model-name", type=str)
    parser.add_argument("--config", type=str)
    parser.add_argument("--port", type=int)
    parser.add_argument("--tensor-parallel-size", type=int)
    parser.add_argument("--trust-remote-code", action="store_true")
    return parser


def test_underscore_to_dash(parser):
    """An underscored option name resolves to the dashed option definition."""
    args = parser.parse_args(["--image_input_type", "pixel_values"])
    assert args.image_input_type == "pixel_values"


def test_mixed_usage(parser):
    """Underscored and dashed option spellings can be mixed in one command."""
    args = parser.parse_args(
        ["--image_input_type", "image_features", "--model-name", "facebook/opt-125m"]
    )
    assert args.image_input_type == "image_features"
    assert args.model_name == "facebook/opt-125m"


def test_with_equals_sign(parser):
    """``--opt=value`` works for both underscored and dashed option names."""
    args = parser.parse_args(
        ["--image_input_type=pixel_values", "--model-name=facebook/opt-125m"]
    )
    assert args.image_input_type == "pixel_values"
    assert args.model_name == "facebook/opt-125m"


def test_with_int_value(parser):
    """``type=int`` conversion applies to both spellings of the option."""
    args = parser.parse_args(["--batch_size", "32"])
    assert args.batch_size == 32
    args = parser.parse_args(["--batch-size", "32"])
    assert args.batch_size == 32


def test_with_bool_flag(parser):
    """A ``store_true`` flag is set by both its underscored and dashed names."""
    args = parser.parse_args(["--enable_feature"])
    assert args.enable_feature is True
    args = parser.parse_args(["--enable-feature"])
    assert args.enable_feature is True


def test_invalid_choice(parser):
    """A value outside ``choices`` makes argparse exit with an error."""
    with pytest.raises(SystemExit):
        parser.parse_args(["--image_input_type", "invalid_choice"])


def test_missing_required_argument(parser):
    """Omitting a ``required=True`` option makes argparse exit with an error."""
    # Safe to mutate: the ``parser`` fixture is function-scoped.
    parser.add_argument("--required-arg", required=True)
    with pytest.raises(SystemExit):
        parser.parse_args([])


def test_cli_override_to_config(parser_with_config, cli_config_file):
    """Explicit CLI values win over ``--config`` values regardless of order.

    ``cli_config_file`` is presumably a conftest fixture pointing at a YAML
    file that sets ``port: 12312`` and a tensor-parallel size other than 3
    (TODO confirm against conftest).
    """
    # CLI value placed after --config.
    args = parser_with_config.parse_args(
        ["serve", "mymodel", "--config", cli_config_file, "--tensor-parallel-size", "3"]
    )
    assert args.tensor_parallel_size == 3

    # CLI value placed before --config; options not on the CLI come from the file.
    args = parser_with_config.parse_args(
        ["serve", "mymodel", "--tensor-parallel-size", "3", "--config", cli_config_file]
    )
    assert args.tensor_parallel_size == 3
    assert args.port == 12312

    # CLI values on both sides of --config.
    args = parser_with_config.parse_args(
        [
            "serve",
            "mymodel",
            "--tensor-parallel-size",
            "3",
            "--config",
            cli_config_file,
            "--port",
            "666",
        ]
    )
    assert args.tensor_parallel_size == 3
    assert args.port == 666


def test_config_args(parser_with_config, cli_config_file):
    """Options absent from the CLI are filled in from the ``--config`` file."""
    args = parser_with_config.parse_args(
        ["serve", "mymodel", "--config", cli_config_file]
    )
    # Presumably the fixture file sets tensor-parallel-size: 2 and
    # trust-remote-code: true (TODO confirm against conftest).
    assert args.tensor_parallel_size == 2
    assert args.trust_remote_code


def test_config_file(parser_with_config):
    """Invalid ``--config`` usage is rejected with a specific exception."""
    # A config path that does not exist (relative to the current working
    # directory; presumably absent in CI).
    with pytest.raises(FileNotFoundError):
        parser_with_config.parse_args(
            ["serve", "mymodel", "--config", "test_config.yml"]
        )

    # A non-YAML config file; presumably rejected because of its extension.
    with pytest.raises(ValueError):
        parser_with_config.parse_args(
            ["serve", "mymodel", "--config", "./data/test_config.json"]
        )

    # ``--config`` followed by another option instead of a file path.
    with pytest.raises(ValueError):
        parser_with_config.parse_args(
            [
                "serve",
                "mymodel",
                "--tensor-parallel-size",
                "3",
                "--config",
                "--batch-size",
                "32",
            ]
        )


def test_no_model_tag(parser_with_config, cli_config_file):
    """``serve`` with neither a positional model nor a model in the config fails.

    Presumably ``cli_config_file`` defines no model (TODO confirm in conftest).
    """
    with pytest.raises(ValueError):
        parser_with_config.parse_args(["serve", "--config", cli_config_file])


def test_dict_args(parser):
    """Dotted ``--opt.key value`` arguments build nested dicts for JSON options.

    Covers:
    - nesting and ``=`` syntax;
    - underscores in the option name, while dict keys keep their spelling;
    - scalar detection (int, float, bool, null);
    - values containing ``-`` and ``.``;
    - list building via ``key+``, including comma-separated values;
    - ``-O<n>`` setting ``--optimization-level``.
    """
    args = [
        "--model-name=something.something",
        "--hf-overrides.key1",
        "val1",
        # Test nesting
        "--hf-overrides.key2.key3",
        "val2",
        "--hf-overrides.key2.key4",
        "val3",
        # Test compile config and compilation mode
        "-cc.use_inductor_graph_partition=true",
        "-cc.backend",
        "custom",
        "-O1",
        # Test = sign
        "--hf-overrides.key5=val4",
        # Test underscore to dash conversion
        "--hf_overrides.key_6",
        "val5",
        "--hf_overrides.key-7.key_8",
        "val6",
        # Test data type detection
        "--hf_overrides.key9",
        "100",
        "--hf_overrides.key10",
        "100.0",
        "--hf_overrides.key11",
        "true",
        "--hf_overrides.key12.key13",
        "null",
        # Test '-' and '.' in value
        "--hf_overrides.key14.key15",
        "-minus.and.dot",
        # Test array values
        "-cc.custom_ops+",
        "-quant_fp8",
        "-cc.custom_ops+=+silu_mul,-rms_norm",
    ]
    parsed_args = parser.parse_args(args)
    assert parsed_args.model_name == "something.something"
    assert parsed_args.hf_overrides == {
        "key1": "val1",
        "key2": {
            "key3": "val2",
            "key4": "val3",
        },
        "key5": "val4",
        # Only the option name is normalized; key spelling is preserved.
        "key_6": "val5",
        "key-7": {
            "key_8": "val6",
        },
        "key9": 100,
        "key10": 100.0,
        "key11": True,
        "key12": {
            "key13": None,
        },
        "key14": {
            "key15": "-minus.and.dot",
        },
    }
    assert parsed_args.optimization_level == 1
    assert parsed_args.compilation_config == {
        "use_inductor_graph_partition": True,
        "backend": "custom",
        # Both ``+`` forms append, in command-line order.
        "custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"],
    }


def test_duplicate_dict_args(caplog_vllm, parser):
    """Repeated options keep the last value and emit one duplicate warning.

    ``caplog_vllm`` is presumably a conftest fixture that captures vLLM log
    records (TODO confirm).
    """
    args = [
        "--model-name=something.something",
        "--hf-overrides.key1",
        "val1",
        "--hf-overrides.key1",
        "val2",
        "-O1",
        "-cc.mode",
        "2",
        "-O3",
    ]

    parsed_args = parser.parse_args(args)
    # Should be the last value
    assert parsed_args.hf_overrides == {"key1": "val2"}
    assert parsed_args.optimization_level == 3
    assert parsed_args.compilation_config == {"mode": 2}

    # A single log record names every duplicated option; ``-O1``/``-O3`` are
    # reported under their long name ``--optimization-level``.
    assert len(caplog_vllm.records) == 1
    assert "duplicate" in caplog_vllm.text
    assert "--hf-overrides.key1" in caplog_vllm.text
    assert "--optimization-level" in caplog_vllm.text


def test_model_specification(
    parser_with_config, cli_config_file, cli_config_file_with_model
):
    """Model resolution for ``serve``: positional, config file, and ``--model``.

    ``cli_config_file_with_model`` presumably sets ``model: config-model`` and
    ``served-model-name: mymodel``. ``cli_config_file`` presumably sets no
    model. TODO confirm both against conftest.
    """
    # Test model in CLI takes precedence over config
    args = parser_with_config.parse_args(
        ["serve", "cli-model", "--config", cli_config_file_with_model]
    )
    assert args.model_tag == "cli-model"
    assert args.served_model_name == "mymodel"

    # Test model from config file works
    args = parser_with_config.parse_args(
        [
            "serve",
            "--config",
            cli_config_file_with_model,
        ]
    )
    assert args.model == "config-model"
    assert args.served_model_name == "mymodel"

    # Test no model specified anywhere raises error
    with pytest.raises(ValueError, match="No model specified!"):
        parser_with_config.parse_args(["serve", "--config", cli_config_file])

    # Test using --model option raises error
    # with pytest.raises(
    #         ValueError,
    #         match=
    #     ("With `vllm serve`, you should provide the model as a positional "
    #      "argument or in a config file instead of via the `--model` option."),
    # ):
    #     parser_with_config.parse_args(['serve', '--model', 'my-model'])

    # Test using --model option back-compatibility
    # (when back-compatibility ends, the above test should be uncommented
    # and the below test should be removed)
    # ``args.model`` is asserted to be None after parsing, i.e. the value
    # given via ``--model`` does not end up in ``args.model``.
    args = parser_with_config.parse_args(
        [
            "serve",
            "--tensor-parallel-size",
            "2",
            "--model",
            "my-model",
            "--trust-remote-code",
            "--port",
            "8001",
        ]
    )
    assert args.model is None
    assert args.tensor_parallel_size == 2
    assert args.trust_remote_code is True
    assert args.port == 8001

    # Same as above, using the ``--opt=value`` form.
    args = parser_with_config.parse_args(
        [
            "serve",
            "--tensor-parallel-size=2",
            "--model=my-model",
            "--trust-remote-code",
            "--port=8001",
        ]
    )
    assert args.model is None
    assert args.tensor_parallel_size == 2
    assert args.trust_remote_code is True
    assert args.port == 8001

    # Test other config values are preserved
    args = parser_with_config.parse_args(
        [
            "serve",
            "cli-model",
            "--config",
            cli_config_file_with_model,
        ]
    )
    assert args.tensor_parallel_size == 2
    assert args.trust_remote_code is True
    assert args.port == 12312


def test_convert_ids_list_to_tokens():
    """``convert_ids_list_to_tokens`` returns display-ready token strings.

    Unlike ``tokenizer.convert_ids_to_tokens``, which yields the byte-level
    marker ``Ġ`` for a leading space, the helper yields a real space.

    NOTE: loads a tokenizer by hub id, so it needs network access or a
    populated Hugging Face cache.
    """
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    token_ids = tokenizer.encode("Hello, world!")
    # token_ids = [9707, 11, 1879, 0]
    assert tokenizer.convert_ids_to_tokens(token_ids) == ["Hello", ",", "Ġworld", "!"]
    tokens = convert_ids_list_to_tokens(tokenizer, token_ids)
    assert tokens == ["Hello", ",", " world", "!"]


def test_load_config_file(tmp_path):
    """``load_config_file`` flattens a YAML mapping into CLI-style arguments.

    As the expected output shows:
    - a ``True`` value becomes a bare flag;
    - a list becomes the option followed by each item;
    - other scalars become ``--key`` followed by their string form.
    """
    # Define the configuration data
    config_data = {
        "enable-logging": True,
        "list-arg": ["item1", "item2"],
        "port": 12323,
        "tensor-parallel-size": 4,
    }

    # Write the configuration data to a temporary YAML file. ``tmp_path`` is a
    # per-test directory managed by pytest, so the file needs no manual removal.
    config_file_path = tmp_path / "config.yaml"
    with open(config_file_path, "w", encoding="utf-8") as config_file:
        yaml.dump(config_data, config_file)

    # Initialize the parser
    parser = FlexibleArgumentParser()

    # Call the function with the temporary file path
    processed_args = parser.load_config_file(str(config_file_path))

    # Expected output
    expected_args = [
        "--enable-logging",
        "--list-arg",
        "item1",
        "item2",
        "--port",
        "12323",
        "--tensor-parallel-size",
        "4",
    ]

    # Assert that the processed arguments match the expected output
    assert processed_args == expected_args


def test_compilation_mode_string_values(parser):
    """Test that -cc.mode accepts both integer and string mode values.

    Integer-looking values are parsed to ``int``. Mode names are kept verbatim
    as strings, in whatever case they were given. ``-O3`` is also checked to
    set ``optimization_level``.
    """
    args = parser.parse_args(["-cc.mode", "0"])
    assert args.compilation_config == {"mode": 0}

    args = parser.parse_args(["-O3"])
    assert args.optimization_level == 3

    args = parser.parse_args(["-cc.mode=NONE"])
    assert args.compilation_config == {"mode": "NONE"}

    args = parser.parse_args(["-cc.mode", "STOCK_TORCH_COMPILE"])
    assert args.compilation_config == {"mode": "STOCK_TORCH_COMPILE"}

    args = parser.parse_args(["-cc.mode=DYNAMO_TRACE_ONCE"])
    assert args.compilation_config == {"mode": "DYNAMO_TRACE_ONCE"}

    args = parser.parse_args(["-cc.mode", "VLLM_COMPILE"])
    assert args.compilation_config == {"mode": "VLLM_COMPILE"}

    # Lower-case names are passed through unchanged, too.
    args = parser.parse_args(["-cc.mode=none"])
    assert args.compilation_config == {"mode": "none"}

    args = parser.parse_args(["-cc.mode=vllm_compile"])
    assert args.compilation_config == {"mode": "vllm_compile"}


def test_compilation_config_mode_validator():
    """``CompilationConfig.mode`` normalizes ints and mode names to the enum.

    Integers and mode names (upper- or lower-case) must map to the matching
    ``CompilationMode`` member. An unknown name must be rejected with a
    pydantic ``ValidationError``.
    """
    from vllm.config.compilation import CompilationConfig, CompilationMode

    # (raw value passed to the config, expected normalized mode)
    accepted = [
        (0, CompilationMode.NONE),
        (3, CompilationMode.VLLM_COMPILE),
        ("NONE", CompilationMode.NONE),
        ("STOCK_TORCH_COMPILE", CompilationMode.STOCK_TORCH_COMPILE),
        ("DYNAMO_TRACE_ONCE", CompilationMode.DYNAMO_TRACE_ONCE),
        ("VLLM_COMPILE", CompilationMode.VLLM_COMPILE),
        ("none", CompilationMode.NONE),
        ("vllm_compile", CompilationMode.VLLM_COMPILE),
    ]
    for raw_mode, expected_mode in accepted:
        assert CompilationConfig(mode=raw_mode).mode == expected_mode

    with pytest.raises(ValidationError, match="Invalid compilation mode"):
        CompilationConfig(mode="INVALID_MODE")


def test_flat_product():
    """``flat_product`` acts like ``itertools.product`` but splices tuples."""
    # Scalar elements: same ordering as itertools.product.
    plain = list(flat_product([1, 2, 3], ["a", "b"]))
    assert plain == [(num, ch) for num in (1, 2, 3) for ch in ("a", "b")]

    # Tuple elements are flattened into the surrounding result tuple.
    flattened = list(flat_product([(1, 2), (3, 4)], ["a", "b"], [(5, 6)]))
    expected = [
        (1, 2, "a", 5, 6),
        (1, 2, "b", 5, 6),
        (3, 4, "a", 5, 6),
        (3, 4, "b", 5, 6),
    ]
    assert flattened == expected