# NOTE: web-viewer residue captured together with this file (not part of the module):
#   "vllm/lora/model_manager.py" did not exist on "3ada34f9cb4d1af763fdfa3b481862a93eb6bd2b"
#   test_mistral.py 9.86 KB
#   Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import copy
import json

import pytest

from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
    MistralToolCall,
    MistralToolParser,
)
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import MistralTokenizer

from ...utils import check_logprobs_close
16

17
# Checkpoints used by `test_models`, which compares vLLM against the HF runner.
MODELS = [
    "mistralai/Mistral-7B-Instruct-v0.3",
]

21
22
# Checkpoints used by the tests that load with `config_format="mistral"` /
# `load_format="mistral"`.
MISTRAL_FORMAT_MODELS = [
    "mistralai/Mistral-7B-Instruct-v0.3",
    # uses the v3-Tekken tokenizer
    "mistralai/Ministral-8B-Instruct-2410",
    # Mistral-Nemo is too big for CI, but passes locally
    # "mistralai/Mistral-Nemo-Instruct-2407"
]

29
# Shared sampling settings for the chat-based tests in this module:
# at most 512 generated tokens, temperature 0.0, and 5 logprobs requested.
SAMPLING_PARAMS = SamplingParams(
    max_tokens=512,
    temperature=0.0,
    logprobs=5,
)
30
31
32
# Prompts in non-Latin scripts (plus an emoji case) used to check that decoded
# output never contains the Unicode replacement character.
SYMBOLIC_LANG_PROMPTS = [
    "勇敢な船乗りについての詩を書く",  # japanese
    "寫一首關於勇敢的水手的詩",  # chinese
    "ပုံပြင်လေးပြောပြပါ်:\n",  # burmese
    "Repeat the phrase 'URGENCY🌶️':\nURGENCY🌶️\nURGENCY🌶️\n",  # see https://github.com/vllm-project/vllm/pull/9625
]
36
37

# for function calling
# Two tool definitions offered to the model: a weather lookup whose three
# arguments are all required, and a text-rewrite helper with no required args.
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to find the weather for, e.g. 'San Francisco'",
                    },
                    "state": {
                        "type": "string",
                        "description": "the two-letter abbreviation for the state that the city is"
                        " in, e.g. 'CA' which would mean 'California'",
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit to fetch the temperature in",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["city", "state", "unit"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "rewrite",
            "description": "Rewrites text",
            "parameters": {
                "type": "object",
                "required": [],
                "properties": {
                    "text": {
                        "type": "string",
                        "description": "The input text to rewrite.",
                    }
                },
            },
        },
    },
]
# Multi-turn conversation used by the function-calling test: a completed
# `rewrite` tool round-trip followed by a new user question about the weather.
MSGS = [
    {"role": "system", "content": "You are an assistant."},
    {
        "role": "user",
        "content": "Could you please rewrite the below article? \n\n My English needs improvving, maybe I make errors.",  # noqa
    },
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "id": "bbc5b7ede",
                "type": "function",
                "function": {
                    "name": "rewrite",
                    "arguments": '{"text":"My English needs improvving, maybe I make errors."}',  # noqa
                },
            }
        ],
    },
    {
        "role": "tool",
        "content": '{"action":"rewrite","outcome":"My English needs improving, maybe I make errors."}',  # noqa
        # Must match the `id` of the assistant tool call above.
        "tool_call_id": "bbc5b7ede",
        "name": "rewrite",
    },
    {
        "role": "assistant",
        "content": "---\n\nMy English needs improving, maybe I make errors",
    },
    {
        "role": "user",
        "content": (
            "Can you tell me what the temperate will be in Dallas, in fahrenheit?"
        ),
    },
]
121

122
123
124
# Example JSON schema describing a person record with skills and work history.
SAMPLE_JSON_SCHEMA = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
        "skills": {
            "type": "array",
            "items": {"type": "string", "maxLength": 10},
            "minItems": 3,
        },
        "work_history": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "company": {"type": "string"},
                    "duration": {"type": "number"},
                    "position": {"type": "string"},
                },
                "required": ["company", "position"],
            },
        },
    },
    "required": ["name", "age", "skills", "work_history"],
}

148
149
150

@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Compare vLLM (Mistral tokenizer mode) logprobs with the HF runner.

    Both runners generate for the same prompts with the same token and
    logprob budget, and the two result lists are handed to
    `check_logprobs_close`.

    Args:
        hf_runner: Fixture yielding a context-managed HF model runner
            (defined outside this file).
        vllm_runner: Fixture yielding a context-managed vLLM model runner
            (defined outside this file).
        example_prompts: Fixture providing the prompts to generate from.
        model: Checkpoint name, parametrized over `MODELS`.
        dtype: Dtype passed to both runners.
        max_tokens: Generation budget passed to both runners.
        num_logprobs: Number of logprobs requested from both runners.
    """
    # TODO(sang): Sliding window should be tested separately.
    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs
        )

    with vllm_runner(model, dtype=dtype, tokenizer_mode="mistral") as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs
        )

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )
179
180


181
@pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_mistral_format(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Compare Mistral-format loading with HF-format loading in vLLM.

    The same checkpoint is run twice: once with the Mistral tokenizer,
    config and load format, and once with `tokenizer_mode="auto"`,
    safetensors weights and the HF config. The two result lists are handed
    to `check_logprobs_close`.

    Args:
        vllm_runner: Fixture yielding a context-managed vLLM model runner
            (defined outside this file).
        example_prompts: Fixture providing the prompts to generate from.
        model: Checkpoint name, parametrized over `MISTRAL_FORMAT_MODELS`.
        dtype: Dtype passed to both runs.
        max_tokens: Generation budget passed to both runs.
        num_logprobs: Number of logprobs requested from both runs.
    """
    with vllm_runner(
        model,
        dtype=dtype,
        tokenizer_mode="mistral",
        load_format="mistral",
        config_format="mistral",
    ) as mistral_format_model:
        mistral_format_outputs = mistral_format_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs
        )

    with vllm_runner(
        model,
        dtype=dtype,
        tokenizer_mode="auto",
        load_format="safetensors",
        config_format="hf",
    ) as hf_format_model:
        hf_format_outputs = hf_format_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs
        )

    check_logprobs_close(
        outputs_0_lst=hf_format_outputs,
        outputs_1_lst=mistral_format_outputs,
        name_0="hf",
        name_1="mistral",
    )
221
222


223
@pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_mistral_symbolic_languages(vllm_runner, model: str, dtype: str) -> None:
    """Check that chat output for non-Latin prompts has no replacement chars.

    Each prompt in `SYMBOLIC_LANG_PROMPTS` is sent as a single user message,
    and the generated text must not contain U+FFFD.

    Args:
        vllm_runner: Fixture yielding a context-managed vLLM model runner
            (defined outside this file).
        model: Checkpoint name, parametrized over `MISTRAL_FORMAT_MODELS`.
        dtype: Dtype passed to the runner.
    """
    with vllm_runner(
        model,
        dtype=dtype,
        max_model_len=8192,
        tokenizer_mode="mistral",
        config_format="mistral",
        load_format="mistral",
    ) as vllm_model:
        for prompt in SYMBOLIC_LANG_PROMPTS:
            msg = {"role": "user", "content": prompt}
            outputs = vllm_model.llm.chat([msg], sampling_params=SAMPLING_PARAMS)
            # The literal below is U+FFFD, the Unicode replacement character.
            assert "�" not in outputs[0].outputs[0].text.strip()
238
239


240
@pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
    """Check that a tool-call completion is generated and parsed correctly.

    The model is given the `MSGS` conversation and the `TOOLS` definitions.
    The raw output must start with the parser's tool-call token, and
    `MistralToolParser` must extract one `get_current_weather` call with
    the expected arguments and no free-text content.

    Args:
        vllm_runner: Fixture yielding a context-managed vLLM model runner
            (defined outside this file).
        model: Checkpoint name, parametrized over `MISTRAL_FORMAT_MODELS`.
        dtype: Dtype passed to the runner.
    """
    with vllm_runner(
        model,
        dtype=dtype,
        tokenizer_mode="mistral",
        config_format="mistral",
        load_format="mistral",
    ) as vllm_model:
        # Deep-copy so the module-level conversation is never mutated.
        msgs = copy.deepcopy(MSGS)
        outputs = vllm_model.llm.chat(
            msgs, tools=TOOLS, sampling_params=SAMPLING_PARAMS
        )

        tokenizer = vllm_model.llm.get_tokenizer()
        tool_parser = MistralToolParser(tokenizer)

        model_output = outputs[0].outputs[0].text.strip()
        # On failure, show the raw output to ease debugging.
        assert model_output.startswith(tool_parser.bot_token), model_output
        parsed_message = tool_parser.extract_tool_calls(model_output, None)

        assert parsed_message.tools_called

        assert MistralToolCall.is_valid_id(parsed_message.tool_calls[0].id)
        assert parsed_message.tool_calls[0].function.name == "get_current_weather"
        assert (
            parsed_message.tool_calls[0].function.arguments
            == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}'
        )  # noqa
        assert parsed_message.content is None
271
272


273
def test_mistral_function_call_nested_json():
    """Ensure that the function-name regex captures the entire outermost
    JSON block, including nested braces."""

    # Create a minimal stub tokenizer that provides the few attributes the
    # parser accesses (`version` and `get_vocab`).
    class _StubMistralTokenizer(MistralTokenizer):
        version = 11  # Satisfy the version check

        def __init__(self):
            # Deliberately skip the parent initializer; the stub only needs
            # the class attribute and `get_vocab` defined here.
            pass

        @staticmethod
        def get_vocab():
            # Provide the special TOOL_CALLS token expected by the parser.
            return {"[TOOL_CALLS]": 0}

    tokenizer = _StubMistralTokenizer()
    parser = MistralToolParser(tokenizer)

    # Craft a model output featuring nested JSON inside the arguments.
    args_dict = {
        "city": "Dallas",
        "state": "TX",
        "unit": "fahrenheit",
        "sub_dict": {"foo": "bar", "inner": {"x": 1, "y": 2}},
    }

    model_output = f"{parser.bot_token}get_current_weather{json.dumps(args_dict)}"

    parsed = parser.extract_tool_calls(model_output, None)

    # Assertions: the tool call is detected and the full nested JSON is parsed
    # without truncation.
    assert parsed.tools_called

    assert MistralToolCall.is_valid_id(parsed.tool_calls[0].id)
    assert parsed.tool_calls[0].function.name == "get_current_weather"
    assert json.loads(parsed.tool_calls[0].function.arguments) == args_dict
    # No additional content outside the tool call should be returned.
    assert parsed.content is None