test_argparse_utils.py 15 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
# ruff: noqa
4

5
import json
6
import os
7

8
import pytest
9
import yaml
10
from transformers import AutoTokenizer
11
from pydantic import ValidationError
12

13
from vllm.tokenizers.detokenizer_utils import convert_ids_list_to_tokens
14

Cyrus Leung's avatar
Cyrus Leung committed
15
16
from vllm.utils.argparse_utils import FlexibleArgumentParser
from ..utils import flat_product
17

18

19
20
21
22
# Tests for FlexibleArgumentParser
@pytest.fixture
def parser():
    """Return a ``FlexibleArgumentParser`` with the options used by these tests."""
    flexible_parser = FlexibleArgumentParser()

    # (option strings, ``add_argument`` keyword arguments), registered in order.
    option_specs = (
        (("--image-input-type",), {"choices": ["pixel_values", "image_features"]}),
        (("--model-name",), {}),
        (("--batch-size",), {"type": int}),
        (("--enable-feature",), {"action": "store_true"}),
        (("--hf-overrides",), {"type": json.loads}),
        (("-cc", "--compilation-config"), {"type": json.loads}),
        (("--optimization-level",), {"type": int}),
    )
    for option_strings, options in option_specs:
        flexible_parser.add_argument(*option_strings, **options)

    return flexible_parser


35
36
37
@pytest.fixture
def parser_with_config():
    """Return a ``FlexibleArgumentParser`` shaped like a ``serve`` command line."""
    serve_parser = FlexibleArgumentParser()

    # Positionals first (their order matters), then the optional flags.
    argument_specs = (
        ("serve", {}),
        ("model_tag", {"nargs": "?"}),
        ("--model", {"type": str}),
        ("--served-model-name", {"type": str}),
        ("--config", {"type": str}),
        ("--port", {"type": int}),
        ("--tensor-parallel-size", {"type": int}),
        ("--trust-remote-code", {"action": "store_true"}),
    )
    for name, options in argument_specs:
        serve_parser.add_argument(name, **options)

    return serve_parser


49
def test_underscore_to_dash(parser):
    """An option spelled with underscores is parsed into the expected attribute."""
    namespace = parser.parse_args(["--image_input_type", "pixel_values"])
    assert namespace.image_input_type == "pixel_values"
52
53
54


def test_mixed_usage(parser):
    """Underscore-style and dash-style options can be mixed in one invocation."""
    argv = [
        "--image_input_type",
        "image_features",
        "--model-name",
        "facebook/opt-125m",
    ]
    namespace = parser.parse_args(argv)
    assert namespace.image_input_type == "image_features"
    assert namespace.model_name == "facebook/opt-125m"
60
61
62
63


def test_with_equals_sign(parser):
    """``--option=value`` syntax works for both underscore and dash spellings."""
    argv = ["--image_input_type=pixel_values", "--model-name=facebook/opt-125m"]
    namespace = parser.parse_args(argv)
    assert namespace.image_input_type == "pixel_values"
    assert namespace.model_name == "facebook/opt-125m"
68
69
70


def test_with_int_value(parser):
    """Both spellings of the batch-size option yield the integer ``32``."""
    for flag in ("--batch_size", "--batch-size"):
        namespace = parser.parse_args([flag, "32"])
        assert namespace.batch_size == 32


def test_with_bool_flag(parser):
    """Both spellings of the ``store_true`` flag set the attribute to ``True``."""
    for flag in ("--enable_feature", "--enable-feature"):
        namespace = parser.parse_args([flag])
        assert namespace.enable_feature is True


def test_invalid_choice(parser):
    """A value outside the declared ``choices`` makes the parser exit."""
    bad_argv = ["--image_input_type", "invalid_choice"]
    with pytest.raises(SystemExit):
        parser.parse_args(bad_argv)
87
88
89


def test_missing_required_argument(parser):
    """Parsing an empty command line exits when a required option is missing."""
    parser.add_argument("--required-arg", required=True)
    empty_argv = []
    with pytest.raises(SystemExit):
        parser.parse_args(empty_argv)
93
94


95
def test_cli_override_to_config(parser_with_config, cli_config_file):
    """Command-line values are kept when ``--config`` is also given, in any order.

    The expected ``port == 12312`` presumably comes from the file behind the
    ``cli_config_file`` fixture, which is defined outside this module.
    """
    positionals = ["serve", "mymodel"]
    config_opt = ["--config", cli_config_file]
    tp_opt = ["--tensor-parallel-size", "3"]

    # CLI option placed after --config.
    namespace = parser_with_config.parse_args(positionals + config_opt + tp_opt)
    assert namespace.tensor_parallel_size == 3

    # CLI option placed before --config; the port is not given on the CLI.
    namespace = parser_with_config.parse_args(positionals + tp_opt + config_opt)
    assert namespace.tensor_parallel_size == 3
    assert namespace.port == 12312

    # CLI options on both sides of --config.
    namespace = parser_with_config.parse_args(
        positionals + tp_opt + config_opt + ["--port", "666"]
    )
    assert namespace.tensor_parallel_size == 3
    assert namespace.port == 666
119
120


121
def test_config_args(parser_with_config, cli_config_file):
    """Values that appear only via ``--config`` show up on the parsed namespace.

    The expected values presumably mirror the ``cli_config_file`` fixture's
    contents, which are defined outside this module.
    """
    argv = ["serve", "mymodel", "--config", cli_config_file]
    namespace = parser_with_config.parse_args(argv)
    assert namespace.tensor_parallel_size == 2
    assert namespace.trust_remote_code
127
128
129
130


def test_config_file(parser_with_config):
    """Bad ``--config`` usages raise the expected exception types."""
    positionals = ["serve", "mymodel"]

    # A config path that is expected not to exist.
    with pytest.raises(FileNotFoundError):
        parser_with_config.parse_args(positionals + ["--config", "test_config.yml"])

    # A ``.json`` config path (presumably rejected for its extension -- confirm
    # against the parser implementation).
    with pytest.raises(ValueError):
        parser_with_config.parse_args(
            positionals + ["--config", "./data/test_config.json"]
        )

    # ``--config`` immediately followed by another option instead of a path.
    dangling_config_argv = positionals + [
        "--tensor-parallel-size",
        "3",
        "--config",
        "--batch-size",
        "32",
    ]
    with pytest.raises(ValueError):
        parser_with_config.parse_args(dangling_config_argv)
152
153


154
def test_no_model_tag(parser_with_config, cli_config_file):
    """``serve`` with only ``--config cli_config_file`` raises ``ValueError``."""
    argv = ["serve", "--config", cli_config_file]
    with pytest.raises(ValueError):
        parser_with_config.parse_args(argv)
157
158


159
160
161
162
163
def test_dict_args(parser):
    """Dotted option names are collected into nested dict-valued arguments."""
    plain_and_nested = [
        "--model-name=something.something",
        "--hf-overrides.key1",
        "val1",
        # Test nesting
        "--hf-overrides.key2.key3",
        "val2",
        "--hf-overrides.key2.key4",
        "val3",
    ]
    compilation = [
        # Test compile config and compilation mode
        "-cc.use_inductor_graph_partition=true",
        "-cc.backend",
        "custom",
        "-O1",
    ]
    equals_sign = [
        # Test = sign
        "--hf-overrides.key5=val4",
    ]
    underscores = [
        # Test underscore to dash conversion
        "--hf_overrides.key_6",
        "val5",
        "--hf_overrides.key-7.key_8",
        "val6",
    ]
    typed_values = [
        # Test data type detection
        "--hf_overrides.key9",
        "100",
        "--hf_overrides.key10",
        "100.0",
        "--hf_overrides.key11",
        "true",
        "--hf_overrides.key12.key13",
        "null",
    ]
    punctuation_in_value = [
        # Test '-' and '.' in value
        "--hf_overrides.key14.key15",
        "-minus.and.dot",
    ]
    list_values = [
        # Test array values
        "-cc.custom_ops+",
        "-quant_fp8",
        "-cc.custom_ops+=+silu_mul,-rms_norm",
    ]
    argv = [
        *plain_and_nested,
        *compilation,
        *equals_sign,
        *underscores,
        *typed_values,
        *punctuation_in_value,
        *list_values,
    ]

    expected_hf_overrides = {
        "key1": "val1",
        "key2": {"key3": "val2", "key4": "val3"},
        "key5": "val4",
        "key_6": "val5",
        "key-7": {"key_8": "val6"},
        "key9": 100,
        "key10": 100.0,
        "key11": True,
        "key12": {"key13": None},
        "key14": {"key15": "-minus.and.dot"},
    }
    expected_compilation_config = {
        "use_inductor_graph_partition": True,
        "backend": "custom",
        "custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"],
    }

    namespace = parser.parse_args(argv)
    assert namespace.model_name == "something.something"
    assert namespace.hf_overrides == expected_hf_overrides
    assert namespace.optimization_level == 1
    assert namespace.compilation_config == expected_compilation_config


def test_duplicate_dict_args(caplog_vllm, parser):
    """Repeated options keep the last value and are reported in one log record."""
    argv = [
        "--model-name=something.something",
        "--hf-overrides.key1",
        "val1",
        "--hf-overrides.key1",
        "val2",
        "-O1",
        "-cc.mode",
        "2",
        "-O3",
    ]
    namespace = parser.parse_args(argv)

    # Should be the last value
    assert namespace.hf_overrides == {"key1": "val2"}
    assert namespace.optimization_level == 3
    assert namespace.compilation_config == {"mode": 2}

    # Exactly one record, naming every duplicated option.
    assert len(caplog_vllm.records) == 1
    for fragment in ("duplicate", "--hf-overrides.key1", "--optimization-level"):
        assert fragment in caplog_vllm.text
252
253


254
255
256
def test_model_specification(
    parser_with_config, cli_config_file, cli_config_file_with_model
):
    """Check how ``serve`` picks up the model from the CLI and from config files.

    Expected config-derived values (``served_model_name``, ``port`` and so on)
    presumably mirror the fixture files, which are defined outside this module.
    """
    parse = parser_with_config.parse_args
    model_config_opt = ["--config", cli_config_file_with_model]

    # Test model in CLI takes precedence over config
    namespace = parse(["serve", "cli-model"] + model_config_opt)
    assert namespace.model_tag == "cli-model"
    assert namespace.served_model_name == "mymodel"

    # Test model from config file works
    namespace = parse(["serve"] + model_config_opt)
    assert namespace.model == "config-model"
    assert namespace.served_model_name == "mymodel"

    # Test no model specified anywhere raises error
    with pytest.raises(ValueError, match="No model specified!"):
        parse(["serve", "--config", cli_config_file])

    # Test using --model option raises error
    # with pytest.raises(
    #         ValueError,
    #         match=
    #     ("With `vllm serve`, you should provide the model as a positional "
    #      "argument or in a config file instead of via the `--model` option."),
    # ):
    #     parser_with_config.parse_args(['serve', '--model', 'my-model'])

    # Test using --model option back-compatibility
    # (when back-compatibility ends, the above test should be uncommented
    # and the below test should be removed)
    back_compat_forms = (
        # Space-separated values.
        [
            "serve",
            "--tensor-parallel-size",
            "2",
            "--model",
            "my-model",
            "--trust-remote-code",
            "--port",
            "8001",
        ],
        # ``=``-joined values.
        [
            "serve",
            "--tensor-parallel-size=2",
            "--model=my-model",
            "--trust-remote-code",
            "--port=8001",
        ],
    )
    for argv in back_compat_forms:
        namespace = parse(argv)
        assert namespace.model is None
        assert namespace.tensor_parallel_size == 2
        assert namespace.trust_remote_code is True
        assert namespace.port == 8001

    # Test other config values are preserved
    namespace = parse(["serve", "cli-model"] + model_config_opt)
    assert namespace.tensor_parallel_size == 2
    assert namespace.trust_remote_code is True
    assert namespace.port == 12312


336
337
338
339
def test_convert_ids_list_to_tokens():
    """``convert_ids_list_to_tokens`` returns readable text for each token id."""
    qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    ids = qwen_tokenizer.encode("Hello, world!")
    # token_ids = [9707, 11, 1879, 0]

    # The tokenizer's own conversion marks the leading space with "Ġ".
    raw_tokens = qwen_tokenizer.convert_ids_to_tokens(ids)
    assert raw_tokens == ["Hello", ",", "Ġworld", "!"]

    # The helper under test is expected to give a literal space instead.
    readable_tokens = convert_ids_list_to_tokens(qwen_tokenizer, ids)
    assert readable_tokens == ["Hello", ",", " world", "!"]
343
344


345
346
347
348
349
350
def test_load_config_file(tmp_path):
    """``load_config_file`` turns a flat YAML mapping into a CLI argument list.

    The file lives under pytest's ``tmp_path``, so it needs no manual cleanup:
    pytest manages that directory itself, and a trailing ``os.remove`` would
    only ever run when the assertion had already passed.
    """
    # Define the configuration data
    config_data = {
        "enable-logging": True,
        "list-arg": ["item1", "item2"],
        "port": 12323,
        "tensor-parallel-size": 4,
    }

    # Write the configuration data to a temporary YAML file
    config_file_path = tmp_path / "config.yaml"
    config_file_path.write_text(yaml.dump(config_data), encoding="utf-8")

    # Initialize the parser
    parser = FlexibleArgumentParser()

    # Call the function with the temporary file path
    processed_args = parser.load_config_file(str(config_file_path))

    # Expected output
    expected_args = [
        "--enable-logging",
        "--list-arg",
        "item1",
        "item2",
        "--port",
        "12323",
        "--tensor-parallel-size",
        "4",
    ]

    # Assert that the processed arguments match the expected output
    assert processed_args == expected_args
380
381


382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
def test_load_config_file_nested(tmp_path):
    """A nested YAML mapping is emitted as a JSON string argument."""
    nested_value = {"pass_config": {"fuse_allreduce_rms": True}}
    yaml_path = tmp_path / "nested_config.yaml"
    with open(yaml_path, "w") as stream:
        yaml.dump({"port": 8000, "compilation-config": nested_value}, stream)

    cli_args = FlexibleArgumentParser().load_config_file(str(yaml_path))

    def value_after(flag):
        # The value is the element immediately following its flag.
        return cli_args[cli_args.index(flag) + 1]

    assert value_after("--port") == "8000"
    decoded = json.loads(value_after("--compilation-config"))
    assert decoded == {"pass_config": {"fuse_allreduce_rms": True}}


def test_nested_config_end_to_end(tmp_path):
    """A nested mapping in a YAML ``--config`` file round-trips through parsing."""
    compilation_settings = {
        "mode": 3,
        "pass_config": {"fuse_allreduce_rms": True},
    }
    yaml_path = tmp_path / "nested_config.yaml"
    with open(yaml_path, "w") as stream:
        yaml.dump({"compilation-config": compilation_settings}, stream)

    config_parser = FlexibleArgumentParser()
    config_parser.add_argument("-cc", "--compilation-config", type=json.loads)
    namespace = config_parser.parse_args(["--config", str(yaml_path)])

    assert namespace.compilation_config == {
        "mode": 3,
        "pass_config": {"fuse_allreduce_rms": True},
    }


426
def test_compilation_mode_string_values(parser):
    """``-cc.mode`` passes integer and string mode values through unchanged."""
    # (argv, namespace attribute to inspect, expected value), run in order.
    cases = (
        (["-cc.mode", "0"], "compilation_config", {"mode": 0}),
        (["-O3"], "optimization_level", 3),
        (["-cc.mode=NONE"], "compilation_config", {"mode": "NONE"}),
        (
            ["-cc.mode", "STOCK_TORCH_COMPILE"],
            "compilation_config",
            {"mode": "STOCK_TORCH_COMPILE"},
        ),
        (
            ["-cc.mode=DYNAMO_TRACE_ONCE"],
            "compilation_config",
            {"mode": "DYNAMO_TRACE_ONCE"},
        ),
        (["-cc.mode", "VLLM_COMPILE"], "compilation_config", {"mode": "VLLM_COMPILE"}),
        (["-cc.mode=none"], "compilation_config", {"mode": "none"}),
        (["-cc.mode=vllm_compile"], "compilation_config", {"mode": "vllm_compile"}),
    )
    for argv, attribute, expected in cases:
        namespace = parser.parse_args(argv)
        assert getattr(namespace, attribute) == expected


def test_compilation_config_mode_validator():
    """``CompilationConfig`` maps int and string modes to ``CompilationMode``."""
    from vllm.config.compilation import CompilationConfig, CompilationMode

    # (value given as ``mode``, expected enum member), checked in order.
    accepted = (
        (0, CompilationMode.NONE),
        (3, CompilationMode.VLLM_COMPILE),
        ("NONE", CompilationMode.NONE),
        ("STOCK_TORCH_COMPILE", CompilationMode.STOCK_TORCH_COMPILE),
        ("DYNAMO_TRACE_ONCE", CompilationMode.DYNAMO_TRACE_ONCE),
        ("VLLM_COMPILE", CompilationMode.VLLM_COMPILE),
        ("none", CompilationMode.NONE),
        ("vllm_compile", CompilationMode.VLLM_COMPILE),
    )
    for raw_mode, expected_mode in accepted:
        assert CompilationConfig(mode=raw_mode).mode == expected_mode

    # An unknown mode name fails validation.
    with pytest.raises(ValidationError, match="Invalid compilation mode"):
        CompilationConfig(mode="INVALID_MODE")


485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
def test_flat_product():
    """``flat_product`` matches a plain product and flattens tuple elements."""
    # Check regular itertools.product behavior
    plain = list(flat_product([1, 2, 3], ["a", "b"]))
    expected_plain = [(number, letter) for number in (1, 2, 3) for letter in "ab"]
    assert plain == expected_plain

    # check that the tuples get flattened
    flattened = list(flat_product([(1, 2), (3, 4)], ["a", "b"], [(5, 6)]))
    expected_flattened = [
        (1, 2, "a", 5, 6),
        (1, 2, "b", 5, 6),
        (3, 4, "a", 5, 6),
        (3, 4, "b", 5, 6),
    ]
    assert flattened == expected_flattened