test_cli_args.py 9.17 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import json

6
7
import pytest

8
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
9
from vllm.entrypoints.openai.models.protocol import LoRAModulePath
10
from vllm.utils.argparse_utils import FlexibleArgumentParser
11

12
13
from ...utils import VLLM_PATH

14
15
16
# JSON-form LoRA module spec reused by the --lora-modules parsing tests.
LORA_MODULE = {
    "name": "module2",
    "path": "/path/to/module2",
    "base_model_name": "llama",
}

# Example chat template used by the chat-template validation tests; fail at
# import time if it is missing so those tests don't fail for an unrelated reason.
CHATML_JINJA_PATH = VLLM_PATH / "examples/template_chatml.jinja"
assert CHATML_JINJA_PATH.exists()
21
22


23
24
25
26
27
28
29
30
31
32
33
34
35
def _build_vllm_parsers():
    """Create a top-level ``vllm`` parser with a ``serve`` subcommand attached.

    Returns a dict keyed by command path: ``"vllm"`` maps to the root parser
    and ``"vllm serve"`` maps to the subcommand parser populated by
    ``make_arg_parser``.
    """
    root = FlexibleArgumentParser()
    serve = root.add_subparsers().add_parser("serve")
    make_arg_parser(serve)
    return {"vllm": root, "vllm serve": serve}


@pytest.fixture
def vllm_parser():
    """Freshly built root ``vllm`` parser (``serve`` subcommand included)."""
    parsers = _build_vllm_parsers()
    return parsers["vllm"]


36
37
@pytest.fixture
def serve_parser():
    """Freshly built parser for the ``vllm serve`` subcommand."""
    return _build_vllm_parsers()["vllm serve"]
39
40


41
42
43
44
### Test config parsing
def test_config_arg_parsing(serve_parser, cli_config_file):
    """Check --config loading and that an explicit --port takes precedence.

    ``cli_config_file`` is a fixture defined outside this file; the assertions
    below expect it to set ``port`` to 12312.
    """
    # No flags at all.
    args = serve_parser.parse_args([])
    assert args.port == 8000

    # Config file only: the port comes from the config.
    args = serve_parser.parse_args(["--config", cli_config_file])
    assert args.port == 12312

    # An explicit --port overrides the config whether it comes after --config ...
    args = serve_parser.parse_args(
        [
            "--config",
            cli_config_file,
            "--port",
            "9000",
        ]
    )
    assert args.port == 9000

    # ... or before it.
    args = serve_parser.parse_args(
        [
            "--port",
            "9000",
            "--config",
            cli_config_file,
        ]
    )
    assert args.port == 9000


67
### Tests for LoRA module parsing
def test_valid_key_value_format(serve_parser):
    """The legacy ``name=path`` form is parsed into a LoRAModulePath."""
    args = serve_parser.parse_args(
        [
            "--lora-modules",
            "module1=/path/to/module1",
        ]
    )
    expected = [LoRAModulePath(name="module1", path="/path/to/module1")]
    assert args.lora_modules == expected


def test_valid_json_format(serve_parser):
    """A JSON object (including base_model_name) is parsed into a LoRAModulePath."""
    args = serve_parser.parse_args(
        [
            "--lora-modules",
            json.dumps(LORA_MODULE),
        ]
    )
    expected = [
        LoRAModulePath(name="module2", path="/path/to/module2", base_model_name="llama")
    ]
    assert args.lora_modules == expected


def test_invalid_json_format(serve_parser):
    """Malformed JSON (missing closing brace) makes argument parsing exit."""
    with pytest.raises(SystemExit):
        serve_parser.parse_args(
            ["--lora-modules", '{"name": "module3", "path": "/path/to/module3"']
        )
100

101
102
103
104

def test_invalid_type_error(serve_parser):
    """A value that is neither JSON nor ``key=value`` makes argument parsing exit."""
    with pytest.raises(SystemExit):
        serve_parser.parse_args(
            [
                "--lora-modules",
                "invalid_format",  # This is not JSON or key=value format
            ]
        )
111
112
113
114
115


def test_invalid_json_field(serve_parser):
    """Well-formed JSON missing the required ``path`` field makes parsing exit."""
    with pytest.raises(SystemExit):
        serve_parser.parse_args(
            [
                "--lora-modules",
                '{"name": "module4"}',  # Missing required 'path' field
            ]
        )
122
123


124
125
def test_empty_values(serve_parser):
    """An empty string for --lora-modules yields an empty module list."""
    args = serve_parser.parse_args(["--lora-modules", ""])
    assert args.lora_modules == []


def test_multiple_valid_inputs(serve_parser):
    """Legacy and JSON forms can be mixed; order of the inputs is preserved."""
    args = serve_parser.parse_args(
        [
            "--lora-modules",
            "module1=/path/to/module1",
            json.dumps(LORA_MODULE),
        ]
    )
    expected = [
        LoRAModulePath(name="module1", path="/path/to/module1"),
        LoRAModulePath(
            name="module2", path="/path/to/module2", base_model_name="llama"
        ),
    ]
    assert args.lora_modules == expected


### Tests for serve argument validation that run prior to loading
def test_enable_auto_choice_passes_without_tool_call_parser(serve_parser):
    """--enable-auto-tool-choice without --tool-call-parser must be rejected.

    NOTE: despite "passes" in its name, this test asserts that validation
    *fails* with a TypeError.
    """
    parsed = serve_parser.parse_args(args=["--enable-auto-tool-choice"])
    with pytest.raises(TypeError):
        validate_parsed_serve_args(parsed)


def test_enable_auto_choice_passes_with_tool_call_parser(serve_parser):
    """Ensure validation passes with tool choice enabled with a call parser"""
    args = serve_parser.parse_args(
        args=[
            "--enable-auto-tool-choice",
            "--tool-call-parser",
            "mistral",
        ]
    )
    # Must not raise.
    validate_parsed_serve_args(args)


169
170
def test_enable_auto_choice_fails_with_enable_reasoning(serve_parser):
    """Ensure validation fails if reasoning is enabled with auto tool choice"""
    # NOTE(review): this argv also omits --tool-call-parser, and that omission
    # alone is rejected with TypeError by
    # test_enable_auto_choice_passes_without_tool_call_parser. So this test may
    # pass without the reasoning parser being the cause. Confirm the intent.
    args = serve_parser.parse_args(
        args=[
            "--enable-auto-tool-choice",
            "--reasoning-parser",
            "deepseek_r1",
        ]
    )
    with pytest.raises(TypeError):
        validate_parsed_serve_args(args)


182
def test_passes_with_reasoning_parser(serve_parser):
    """Ensure validation passes if reasoning is enabled
    with a reasoning parser"""
    args = serve_parser.parse_args(
        args=[
            "--reasoning-parser",
            "deepseek_r1",
        ]
    )
    # Must not raise.
    validate_parsed_serve_args(args)


194
195
196
def test_chat_template_validation_for_happy_paths(serve_parser):
    """Ensure validation passes if the chat template exists"""
    args = serve_parser.parse_args(
        args=["--chat-template", CHATML_JINJA_PATH.absolute().as_posix()]
    )
    # Must not raise.
    validate_parsed_serve_args(args)


def test_chat_template_validation_for_sad_paths(serve_parser):
    """A --chat-template pointing at a missing file must fail validation."""
    missing_template = "does/not/exist"
    parsed = serve_parser.parse_args(args=["--chat-template", missing_template])
    with pytest.raises(ValueError):
        validate_parsed_serve_args(parsed)
207
208
209
210


@pytest.mark.parametrize(
    "cli_args, expected_middleware",
    [
        # Repeated --middleware flags accumulate in order.
        (
            ["--middleware", "middleware1", "--middleware", "middleware2"],
            ["middleware1", "middleware2"],
        ),
        # No flag gives an empty list.
        ([], []),
    ],
)
def test_middleware(serve_parser, cli_args, expected_middleware):
    """Ensure multiple middleware args are parsed properly"""
    args = serve_parser.parse_args(args=cli_args)
    assert args.middleware == expected_middleware
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255


def test_default_chat_template_kwargs_parsing(serve_parser):
    """--default-chat-template-kwargs JSON is decoded into a dict (false -> False)."""
    raw_json = '{"enable_thinking": false}'
    parsed = serve_parser.parse_args(
        args=["--default-chat-template-kwargs", raw_json]
    )
    expected = {"enable_thinking": False}
    assert parsed.default_chat_template_kwargs == expected


def test_default_chat_template_kwargs_complex(serve_parser):
    """Multi-key JSON with bool, str and int values is decoded intact."""
    expected = {"enable_thinking": False, "custom_param": "value", "num": 42}
    kwargs_json = '{"enable_thinking": false, "custom_param": "value", "num": 42}'
    parsed = serve_parser.parse_args(
        args=["--default-chat-template-kwargs", kwargs_json]
    )
    assert parsed.default_chat_template_kwargs == expected


def test_default_chat_template_kwargs_default_none(serve_parser):
    """Without the flag, default_chat_template_kwargs stays None."""
    parsed = serve_parser.parse_args(args=[])
    kwargs = parsed.default_chat_template_kwargs
    assert kwargs is None


def test_default_chat_template_kwargs_invalid_json(serve_parser):
    """Non-JSON input for --default-chat-template-kwargs makes parsing exit."""
    bad_argv = ["--default-chat-template-kwargs", "not valid json"]
    with pytest.raises(SystemExit):
        serve_parser.parse_args(args=bad_argv)
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293


@pytest.mark.parametrize(
    "args, raises",
    [
        (["user/model"], None),
        (["user/model", "--served-model-name", "model"], None),
        (["--served-model-name", "model", "user/model"], ValueError),
        (["--served-model-name", "model", "--config", "config.yaml"], None),
        (["--served-model-name", "model", "--config", "config.yaml"], ValueError),
    ],
    ids=[
        "model_tag_only",
        "model_tag_with_served_model_name",
        "served_model_name_before_model_tag",
        "served_model_name_with_model_in_config",
        "served_model_name_with_no_model_in_config",
    ],
)
def test_served_model_name_parsing(tmp_path, vllm_parser, args, raises):
    """Ensure that users don't misuse --served-model-name and end up with the default
    model tag instead of the one they intended to serve."""
    # Build a new argv (prefixed with the ``serve`` subcommand) instead of
    # mutating ``args`` in place. ``args`` is the list object taken directly
    # from the parametrize decorator, so in-place edits would corrupt the case
    # if it is ever re-run.
    argv = ["serve", *args]
    if "config.yaml" in argv:
        # Write a real config file. The passing case puts the model in the
        # config; the failing case only sets an unrelated key, so no model is
        # given anywhere.
        config_path = tmp_path / "config.yaml"
        config_path.write_text("model: user/model" if raises is None else "port: 8000")
        argv[argv.index("config.yaml")] = config_path.as_posix()
    # Parse, then either check the resolved model or the expected exception.
    if raises is None:
        parsed_args = vllm_parser.parse_args(args=argv)
        expected = "user/model"
        assert parsed_args.model_tag == expected or parsed_args.model == expected
    else:
        with pytest.raises(raises):
            vllm_parser.parse_args(args=argv)