test_cli_args.py 5.84 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import json

6
7
8
9
import pytest

from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                              validate_parsed_serve_args)
10
from vllm.entrypoints.openai.serving_models import LoRAModulePath
11
12
from vllm.utils import FlexibleArgumentParser

13
14
from ...utils import VLLM_PATH

15
16
17
18
19
# JSON-form LoRA module spec reused by the --lora-modules parsing tests.
LORA_MODULE = dict(
    name="module2",
    path="/path/to/module2",
    base_model_name="llama",
)

# Chat template under VLLM_PATH/examples. The assert runs at import time, so
# a missing template fails loudly before any test that depends on it runs.
CHATML_JINJA_PATH = VLLM_PATH / "examples" / "template_chatml.jinja"
assert CHATML_JINJA_PATH.exists()


@pytest.fixture
def serve_parser():
    """Return a freshly built OpenAI-server argument parser for each test."""
    base_parser = FlexibleArgumentParser(
        description="vLLM's remote OpenAI server.")
    return make_arg_parser(base_parser)


### Test config parsing
def test_config_arg_parsing(serve_parser, cli_config_file):
    """Check --config loading and its precedence relative to --port.

    ``cli_config_file`` is a fixture supplied outside this module; the
    assertions below expect it to set the port to 12312.
    """
    # With no flags the parser's default port applies.
    assert serve_parser.parse_args([]).port == 8000
    # The config file alone overrides the default.
    assert serve_parser.parse_args(['--config', cli_config_file]).port == 12312
    # An explicit --port wins over the config file whichever comes first.
    config_first = ['--config', cli_config_file, '--port', '9000']
    port_first = ['--port', '9000', '--config', cli_config_file]
    for argv in (config_first, port_first):
        assert serve_parser.parse_args(argv).port == 9000


### Tests for LoRA module parsing
def test_valid_key_value_format(serve_parser):
    """The legacy ``name=path`` syntax yields one LoRAModulePath."""
    argv = ['--lora-modules', 'module1=/path/to/module1']
    parsed = serve_parser.parse_args(argv)
    assert parsed.lora_modules == [
        LoRAModulePath(name='module1', path='/path/to/module1')
    ]


def test_valid_json_format(serve_parser):
    """A JSON object passed to --lora-modules is decoded into a
    LoRAModulePath, including the optional base_model_name."""
    module_json = json.dumps(LORA_MODULE)
    parsed = serve_parser.parse_args(['--lora-modules', module_json])
    wanted = LoRAModulePath(name='module2',
                            path='/path/to/module2',
                            base_model_name='llama')
    assert parsed.lora_modules == [wanted]


def test_invalid_json_format(serve_parser):
    """Malformed JSON given to --lora-modules makes the parser exit."""
    # The closing brace is deliberately missing from the JSON below.
    with pytest.raises(SystemExit):
        serve_parser.parse_args([
            '--lora-modules', '{"name": "module3", "path": "/path/to/module3"'
        ])

84
85
86
87
88

def test_invalid_type_error(serve_parser):
    """A --lora-modules value that is neither JSON nor ``key=value``
    makes the parser exit."""
    with pytest.raises(SystemExit):
        serve_parser.parse_args([
            '--lora-modules',
            'invalid_format',  # This is not JSON or key=value format
        ])


def test_invalid_json_field(serve_parser):
    """Well-formed JSON that lacks the required ``path`` field makes the
    parser exit."""
    with pytest.raises(SystemExit):
        serve_parser.parse_args([
            '--lora-modules',
            '{"name": "module4"}',  # Missing required 'path' field
        ])


def test_empty_values(serve_parser):
    """An empty --lora-modules value parses to an empty module list."""
    parsed = serve_parser.parse_args(['--lora-modules', ''])
    assert parsed.lora_modules == []


def test_multiple_valid_inputs(serve_parser):
    """Legacy ``name=path`` and JSON entries can be mixed in one flag;
    the resulting list keeps the command-line order."""
    legacy_entry = 'module1=/path/to/module1'
    json_entry = json.dumps(LORA_MODULE)
    parsed = serve_parser.parse_args(
        ['--lora-modules', legacy_entry, json_entry])
    assert parsed.lora_modules == [
        LoRAModulePath(name='module1', path='/path/to/module1'),
        LoRAModulePath(name='module2',
                       path='/path/to/module2',
                       base_model_name='llama'),
    ]


### Tests for serve argument validation that run prior to loading
def test_enable_auto_choice_passes_without_tool_call_parser(serve_parser):
    """Validation must reject --enable-auto-tool-choice when no
    --tool-call-parser is supplied."""
    # NOTE(review): despite "passes" in the name, this test asserts that
    # validation FAILS; the name is kept so existing test IDs stay stable.
    parsed = serve_parser.parse_args(args=["--enable-auto-tool-choice"])
    with pytest.raises(TypeError):
        validate_parsed_serve_args(parsed)


def test_enable_auto_choice_passes_with_tool_call_parser(serve_parser):
    """Auto tool choice paired with a tool-call parser validates cleanly."""
    cli = ["--enable-auto-tool-choice", "--tool-call-parser", "mistral"]
    validate_parsed_serve_args(serve_parser.parse_args(args=cli))


def test_enable_auto_choice_fails_with_enable_reasoning(serve_parser):
    """Ensure validation fails if reasoning is enabled with auto tool choice"""
    # NOTE(review): no --tool-call-parser is passed here, so the expected
    # TypeError may be raised by the same missing-parser check exercised in
    # test_enable_auto_choice_passes_without_tool_call_parser rather than by
    # anything reasoning-specific -- confirm the intended coverage.
    args = serve_parser.parse_args(args=[
        "--enable-auto-tool-choice",
        "--reasoning-parser",
        "deepseek_r1",
    ])
    with pytest.raises(TypeError):
        validate_parsed_serve_args(args)


def test_passes_with_reasoning_parser(serve_parser):
    """Ensure validation passes if reasoning is enabled with a reasoning
    parser."""
    args = serve_parser.parse_args(args=[
        "--reasoning-parser",
        "deepseek_r1",
    ])
    validate_parsed_serve_args(args)


def test_chat_template_validation_for_happy_paths(serve_parser):
    """A chat-template path that exists on disk passes validation."""
    template_path = CHATML_JINJA_PATH.absolute().as_posix()
    parsed = serve_parser.parse_args(args=["--chat-template", template_path])
    validate_parsed_serve_args(parsed)


def test_chat_template_validation_for_sad_paths(serve_parser):
    """A chat-template path that does not exist is rejected with
    ValueError."""
    missing_path = "does/not/exist"
    parsed = serve_parser.parse_args(args=["--chat-template", missing_path])
    with pytest.raises(ValueError):
        validate_parsed_serve_args(parsed)


@pytest.mark.parametrize(
    "cli_args, expected_middleware",
    [
        (["--middleware", "middleware1", "--middleware", "middleware2"],
         ["middleware1", "middleware2"]),
        ([], []),
    ],
)
def test_middleware(serve_parser, cli_args, expected_middleware):
    """Repeated --middleware flags accumulate in order; none yields []."""
    parsed = serve_parser.parse_args(args=cli_args)
    assert parsed.middleware == expected_middleware