# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import json

import pytest

from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
    MistralToolCall,
    MistralToolParser,
)
from vllm.sampling_params import SamplingParams
from vllm.tokenizers import MistralTokenizer

from ...utils import check_logprobs_close

# Models whose HF-runner and vLLM outputs are compared in test_models.
MODELS = [
    "mistralai/Mistral-7B-Instruct-v0.3",
]

# Models exercised with the Mistral tokenizer/config/load formats
# (test_mistral_format, test_mistral_symbolic_languages and
# test_mistral_function_calling).
MISTRAL_FORMAT_MODELS = [
    "mistralai/Mistral-7B-Instruct-v0.3",
    # uses the v3-Tekken tokenizer
    "mistralai/Ministral-8B-Instruct-2410",
    # Mistral-Nemo is too big for CI, but passes locally
    # "mistralai/Mistral-Nemo-Instruct-2407"
]

# Sampling settings shared by the chat-based tests (symbolic languages and
# function calling): at most 512 new tokens, temperature 0.0, logprobs=5.
SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)

# Prompts in non-Latin scripts and with emoji. test_mistral_symbolic_languages
# checks that the generated replies contain no U+FFFD replacement characters.
SYMBOLIC_LANG_PROMPTS = [
    "勇敢な船乗りについての詩を書く",  # japanese
    "寫一首關於勇敢的水手的詩",  # chinese
    "ပုံပြင်လေးပြောပြပါ်:\n",  # burmese
    "Repeat the phrase 'URGENCY🌶️':\nURGENCY🌶️\nURGENCY🌶️\n",  # see https://github.com/vllm-project/vllm/pull/9625
]

# Tool schemas offered to the model in test_mistral_function_calling:
# ``get_current_weather`` (city, state and unit, all required) and ``rewrite``
# (a single ``text`` argument, nothing required).
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to find the weather for, e.g. "
                        "'San Francisco'",
                    },
                    "state": {
                        "type": "string",
                        "description": "the two-letter abbreviation for the state that "
                        "the city is in, e.g. 'CA' which would mean 'California'",
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit to fetch the temperature in",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["city", "state", "unit"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "rewrite",
            "description": "Rewrites text",
            "parameters": {
                "type": "object",
                "required": [],
                "properties": {
                    "text": {
                        "type": "string",
                        "description": "The input text to rewrite.",
                    }
                },
            },
        },
    },
]

# Chat history for test_mistral_function_calling. It contains a completed
# ``rewrite`` tool round-trip (the tool message's ``tool_call_id`` matches the
# assistant call id), then a new user question. The test expects that question
# to be answered with a ``get_current_weather`` call.
# The misspelling "improvving" is deliberate; the tool result corrects it.
# NOTE: "temperate" (sic) in the last prompt is kept as-is. Editing the prompt
# could change the greedy output that the test asserts on.
MSGS = [
    {"role": "system", "content": "You are an assistant."},
    {
        "role": "user",
        "content": "Could you please rewrite the below article? \n\n My English needs "
        "improvving, maybe I make errors.",
    },
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "id": "bbc5b7ede",
                "type": "function",
                "function": {
                    "name": "rewrite",
                    "arguments": '{"text":"My English needs improvving, maybe '
                    'I make errors."}',
                },
            }
        ],
    },
    {
        "role": "tool",
        "content": '{"action":"rewrite","outcome":"My English needs improving, maybe '
        'I make errors."}',
        "tool_call_id": "bbc5b7ede",
        "name": "rewrite",
    },
    {
        "role": "assistant",
        "content": "---\n\nMy English needs improving, maybe I make errors",
    },
    {
        "role": "user",
        "content": (
            "Can you tell me what the temperate will be in Dallas, in fahrenheit?"
        ),
    },
]

# JSON-schema fixture describing a person record with skills and work history.
# NOTE(review): no test in this module references it. Confirm it is still
# needed (e.g. by a structured-output test elsewhere) before removing it.
SAMPLE_JSON_SCHEMA = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
        "skills": {
            "type": "array",
            "items": {"type": "string", "maxLength": 10},
            "minItems": 3,
        },
        "work_history": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "company": {"type": "string"},
                    "duration": {"type": "number"},
                    "position": {"type": "string"},
                },
                "required": ["company", "position"],
            },
        },
    },
    "required": ["name", "age", "skills", "work_history"],
}


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Compare greedy-decoding logprobs from the HF runner and from vLLM.

    The HF reference uses ``generate_greedy_logprobs_limit``; vLLM runs with
    ``tokenizer_mode="mistral"``. Both receive the same ``example_prompts``,
    ``max_tokens`` and ``num_logprobs``. The two output lists are compared
    with ``check_logprobs_close``.
    """
    # TODO(sang): Sliding window should be tested separately.
    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs
        )

    with vllm_runner(model, dtype=dtype, tokenizer_mode="mistral") as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs
        )

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )


@pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_mistral_format(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Check that both load paths give close greedy logprobs for one checkpoint.

    The first load path uses Mistral format: mistral tokenizer, config and
    load format. The second uses HF format: HF tokenizer and config with
    safetensors weights. The outputs are compared with
    ``check_logprobs_close``, using the HF-format run as the reference.
    """
    with vllm_runner(
        model,
        dtype=dtype,
        tokenizer_mode="mistral",
        load_format="mistral",
        config_format="mistral",
    ) as mistral_format_model:
        mistral_format_outputs = mistral_format_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs
        )

    with vllm_runner(
        model,
        dtype=dtype,
        tokenizer_mode="hf",
        load_format="safetensors",
        config_format="hf",
    ) as hf_format_model:
        hf_format_outputs = hf_format_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs
        )

    check_logprobs_close(
        outputs_0_lst=hf_format_outputs,
        outputs_1_lst=mistral_format_outputs,
        name_0="hf",
        name_1="mistral",
    )


@pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_mistral_symbolic_languages(vllm_runner, model: str, dtype: str) -> None:
    """Chat once per prompt in SYMBOLIC_LANG_PROMPTS and check each reply.

    The reply must not contain the U+FFFD replacement character ("�"). That
    character is the conventional stand-in for bytes that could not be
    decoded into valid text.
    """
    with vllm_runner(
        model,
        dtype=dtype,
        max_model_len=8192,
        tokenizer_mode="mistral",
        config_format="mistral",
        load_format="mistral",
    ) as vllm_model:
        for prompt in SYMBOLIC_LANG_PROMPTS:
            msg = {"role": "user", "content": prompt}
            outputs = vllm_model.llm.chat([msg], sampling_params=SAMPLING_PARAMS)
            # No .strip() needed: U+FFFD is not whitespace, so stripping could
            # never change the outcome of this membership check.
            text = outputs[0].outputs[0].text
            # Report the prompt and output so a failure shows which case broke.
            assert "�" not in text, f"prompt={prompt!r} output={text!r}"


@pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None:
    """Run the MSGS chat with TOOLS and check the tool call in the reply.

    The reply must start with the parser's tool-call token. It must parse into
    a ``get_current_weather`` call for Dallas, TX in fahrenheit, with no extra
    content outside the tool call.
    """
    with vllm_runner(
        model,
        dtype=dtype,
        tokenizer_mode="mistral",
        config_format="mistral",
        load_format="mistral",
    ) as vllm_model:
        # Work on a deep copy so the module-level MSGS cannot be mutated.
        # Presumably this guards against chat() modifying messages in place.
        msgs = copy.deepcopy(MSGS)
        outputs = vllm_model.llm.chat(
            msgs, tools=TOOLS, sampling_params=SAMPLING_PARAMS
        )

        tool_parser = MistralToolParser(vllm_model.llm.get_tokenizer())

        model_output = outputs[0].outputs[0].text.strip()
        assert model_output.startswith(tool_parser.bot_token), model_output
        parsed_message = tool_parser.extract_tool_calls(model_output, None)

        assert parsed_message.tools_called, model_output

        tool_call = parsed_message.tool_calls[0]
        assert MistralToolCall.is_valid_id(tool_call.id)
        assert tool_call.function.name == "get_current_weather"
        # Compare parsed JSON instead of the raw string. The check then does not
        # depend on key order or separator whitespace, matching
        # test_mistral_function_call_nested_json.
        assert json.loads(tool_call.function.arguments) == {
            "city": "Dallas",
            "state": "TX",
            "unit": "fahrenheit",
        }
        assert parsed_message.content is None


def test_mistral_function_call_nested_json():
    """Ensure that the function-name regex captures the entire outermost
    JSON block, including nested braces.

    Covers two cases. The first is a single tool call whose arguments contain
    nested objects. The second is several tool calls concatenated back to
    back, including empty and flat argument dicts.
    """

    # Create a minimal stub tokenizer that provides the few attributes the
    # parser accesses (`version` and `get_vocab`).
    class _StubMistralTokenizer(MistralTokenizer):
        version = 11  # Satisfy the version check

        def __init__(self):
            # Deliberately skip MistralTokenizer.__init__: the stub only needs
            # the attributes defined on this class.
            pass

        @staticmethod
        def get_vocab():
            # Provide the special TOOL_CALLS token expected by the parser.
            return {"[TOOL_CALLS]": 0}

    parser = MistralToolParser(_StubMistralTokenizer())

    # --- Single call with nested JSON inside the arguments. ---
    args_dict = {
        "city": "Dallas",
        "state": "TX",
        "unit": "fahrenheit",
        "sub_dict": {"foo": "bar", "inner": {"x": 1, "y": 2}},
    }
    model_output = f"{parser.bot_token}get_current_weather{json.dumps(args_dict)}"

    parsed = parser.extract_tool_calls(model_output, None)

    # The tool call is detected and the full nested JSON is parsed without
    # truncation.
    assert parsed.tools_called
    assert MistralToolCall.is_valid_id(parsed.tool_calls[0].id)
    assert parsed.tool_calls[0].function.name == "get_current_weather"
    assert json.loads(parsed.tool_calls[0].function.arguments) == args_dict
    # No additional content outside the tool call should be returned.
    assert parsed.content is None

    # --- Multiple calls concatenated back to back. ---
    multiple_args_dict = [
        {
            "city": "Dallas",
            "state": "TX",
            "unit": "fahrenheit",
            "sub_dict": {"foo": "bar", "inner": {"x": 1, "y": 2}},
        },
        {},
        {"a": 0},
        {"a": 1, "b": "c"},
    ]
    names = ["get_current_weather", "get_current_weather_2", "random", "random_2"]

    model_output = "".join(
        f"{parser.bot_token}{name}{json.dumps(args)}"
        for name, args in zip(names, multiple_args_dict)
    )

    parsed = parser.extract_tool_calls(model_output, None)

    # Every call is detected, in order, each with its complete arguments.
    assert parsed.tools_called
    assert len(parsed.tool_calls) == len(multiple_args_dict)
    # No additional content outside the tool calls should be returned.
    assert parsed.content is None

    for tool_call, name, args in zip(parsed.tool_calls, names, multiple_args_dict):
        assert MistralToolCall.is_valid_id(tool_call.id)
        assert tool_call.function.name == name
        assert json.loads(tool_call.function.arguments) == args