# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
from copy import deepcopy
from unittest.mock import MagicMock

import pytest
8
import regex as re
9
10
from pydantic import TypeAdapter

11
from vllm.entrypoints.openai.chat_completion.protocol import (
12
13
    ChatCompletionToolsParam,
)
14
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
15
from vllm.tool_parsers.utils import get_json_schema_from_tools
16

17
18
# Module-level pytest mark: applies the `cpu_test` marker to every test in
# this file.
# NOTE(review): presumably used to select tests that can run without a GPU —
# confirm against the project's pytest marker configuration.
pytestmark = pytest.mark.cpu_test

19
20
21
22
23
24
25
26
27
28
# Tool definitions shared by the tests in this module: one function with a
# single required string argument, and one with a required string and a
# required integer argument. Both schemas forbid additional properties.
EXAMPLE_TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": (
                            "The city to find the weather for, "
                            "e.g. 'San Francisco'"
                        ),
                    },
                },
                "required": ["city"],
                "additionalProperties": False,
            },
        },
        "strict": True,
    },
    {
        "type": "function",
        "function": {
            "name": "get_forecast",
            "description": "Get the weather forecast for a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": (
                            "The city to get the forecast for, e.g. 'New York'"
                        ),
                    },
                    "days": {
                        "type": "integer",
                        "description": (
                            "Number of days to get the forecast for (1-7)"
                        ),
                    },
                },
                "required": ["city", "days"],
                "additionalProperties": False,
            },
        },
        "strict": True,
    },
]


67
68
69
def _compile_and_check(
    tools: list[ChatCompletionToolsParam], sample_output, should_match: bool
):
    """Assert whether ``sample_output`` fits the "required" tool-call schema.

    The JSON schema for ``tools`` is generated with ``tool_choice="required"``,
    turned into a regular expression, and the JSON-serialized
    ``sample_output`` is full-matched against it.

    Args:
        tools: Tool definitions the schema is generated from.
        sample_output: Any JSON-serializable candidate output.
        should_match: Whether the serialized candidate is expected to match.
    """
    required_schema = get_json_schema_from_tools(tools=tools, tool_choice="required")
    assert isinstance(required_schema, dict)

    # build_regex_from_schema is what JSONLogitsProcessor uses to create its
    # Guide; it is imported inside the function rather than at module level.
    from outlines_core.json_schema import build_regex_from_schema

    pattern = re.compile(build_regex_from_schema(json.dumps(required_schema)))
    serialized_output = json.dumps(sample_output)

    assert (pattern.fullmatch(serialized_output) is not None) == should_match


# (tool_calls, should_match) pairs. Every entry here is a list of tool calls
# that is expected to match the schema generated from EXAMPLE_TOOLS.
VALID_TOOL_OUTPUTS = [
    # one call to the single-argument tool
    ([{"name": "get_current_weather", "parameters": {"city": "Vienna"}}], True),
    # the same tool called twice
    (
        [
            {"name": "get_current_weather", "parameters": {"city": "Vienna"}},
            {"name": "get_current_weather", "parameters": {"city": "Berlin"}},
        ],
        True,
    ),
    # one call to the two-argument tool
    ([{"name": "get_forecast", "parameters": {"city": "Vienna", "days": 7}}], True),
    # one call to each tool
    (
        [
            {"name": "get_forecast", "parameters": {"city": "Vienna", "days": 7}},
            {"name": "get_current_weather", "parameters": {"city": "Vienna"}},
        ],
        True,
    ),
    # four calls alternating between the two tools
    (
        [
            {"name": "get_forecast", "parameters": {"city": "Vienna", "days": 7}},
            {"name": "get_current_weather", "parameters": {"city": "Vienna"}},
            {"name": "get_forecast", "parameters": {"city": "Berlin", "days": 7}},
            {"name": "get_current_weather", "parameters": {"city": "Berlin"}},
        ],
        True,
    ),
]

# Only the tool-call lists, without the expected-match flag.
VALID_TOOLS = [tool_calls for tool_calls, _ in VALID_TOOL_OUTPUTS]


@pytest.mark.parametrize(
    "sample_output, should_match",
    [
        *VALID_TOOL_OUTPUTS,
        (None, False),
        ([], False),  # empty list cannot be generated
        ({}, False),  # empty object cannot be generated
        ([{}], False),  # list with empty object cannot be generated
        # function without its required parameters: "parameters" key missing
        ([{"name": "get_current_weather"}], False),
        # function without its required parameters: empty "parameters" object
        ([{"name": "get_current_weather", "parameters": {}}], False),
        # function without its required parameters: "parameters" is null
        ([{"name": "get_current_weather", "parameters": None}], False),
        # tool call that is not wrapped in a list cannot be generated
        ({"name": "get_current_weather", "parameters": {"city": "Vienna"}}, False),
        # tool call with extra parameters cannot be generated
        (
            [
                {
                    "name": "get_current_weather",
                    "parameters": {"city": "Vienna", "extra": "value"},
                }
            ],
            False,
        ),
        # tool call where "parameters" comes before "name" cannot be generated
        ([{"parameters": {"city": "Vienna"}, "name": "get_current_weather"}], False),
        # tool call without all required parameters cannot be generated
        ([{"name": "get_forecast", "parameters": {"city": "Vienna"}}], False),
        # tool call with incorrect name/parameters cannot be generated
        ([{"name": "get_weather", "parameters": {"city": "Vienna", "days": 7}}], False),
        # tool call with both valid and empty function cannot be generated
        ([{"name": "get_current_weather", "parameters": {"city": "Vienna"}}, {}], False),
    ],
)
def test_structured_outputs_json(sample_output, should_match):
    """Check candidate outputs against the schema built from EXAMPLE_TOOLS.

    EXAMPLE_TOOLS is validated into ``ChatCompletionToolsParam`` objects and
    each parametrized candidate is handed to ``_compile_and_check`` together
    with its expected match result.
    """
    tool_list_adapter = TypeAdapter(list[ChatCompletionToolsParam])
    _compile_and_check(
        tools=tool_list_adapter.validate_python(EXAMPLE_TOOLS),
        sample_output=sample_output,
        should_match=should_match,
    )
202
203


204
def update_parameters_none(tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam:
    """Set ``tool.function.parameters`` to ``None`` in place and return ``tool``."""
    function_def = tool.function
    function_def.parameters = None
    return tool


def update_parameters_empty_dict(
    tool: ChatCompletionToolsParam,
) -> ChatCompletionToolsParam:
    """Set ``tool.function.parameters`` to ``{}`` in place and return ``tool``."""
    function_def = tool.function
    function_def.parameters = {}
    return tool


@pytest.mark.parametrize(
    "sample_output, should_match",
    [
        (None, False),
        ([], False),  # empty list cannot be generated
        ({}, False),  # empty object cannot be generated
        ([{}], False),  # list with empty object cannot be generated
        # function without a "parameters" key cannot be generated
        ([{"name": "get_current_weather"}], False),
        # function whose "parameters" is null cannot be generated
        ([{"name": "get_current_weather", "parameters": None}], False),
        # function with extra parameters cannot be generated
        ([{"name": "get_current_weather", "parameters": {"extra": "value"}}], False),
        # only function with empty parameters object is valid
        ([{"name": "get_current_weather", "parameters": {}}], True),
    ],
)
@pytest.mark.parametrize(
    "update_parameters", [update_parameters_none, update_parameters_empty_dict]
)
def test_structured_outputs_json_without_parameters(
    sample_output, should_match, update_parameters
):
    """Check candidate outputs against a tool whose parameters were cleared.

    A copy of the first example tool is validated, its ``parameters`` are
    replaced by ``update_parameters`` (with ``None`` or ``{}``), and each
    candidate is then checked through ``_compile_and_check``.
    """
    tool_list_adapter = TypeAdapter(list[ChatCompletionToolsParam])
    validated_tools = tool_list_adapter.validate_python([deepcopy(EXAMPLE_TOOLS[0])])
    tools = [update_parameters(tool) for tool in validated_tools]

    # Sanity check: the updater really left every tool without parameters.
    assert all(
        tool.function.parameters is None or tool.function.parameters == {}
        for tool in tools
    )
    _compile_and_check(
        tools=tools, sample_output=sample_output, should_match=should_match
    )
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294


@pytest.mark.parametrize("output", VALID_TOOLS)
@pytest.mark.parametrize("empty_params", [False, True])
@pytest.mark.parametrize("delta_len", list(range(1, 11)))
def test_streaming_output_valid(output, empty_params, delta_len):
    """Stream a valid tool-call JSON in chunks and rebuild it from the deltas.

    The serialized tool calls are fed to
    ``OpenAIServingChat.extract_tool_call_required_streaming`` in slices of
    ``delta_len`` characters. The delta messages it returns are stitched back
    into one JSON document, which must equal the original both as parsed data
    and as re-serialized text.
    """
    serving_stub = MagicMock()

    expected_calls = deepcopy(output)
    if empty_params:
        expected_calls = [
            {"name": call["name"], "parameters": {}} for call in expected_calls
        ]
    expected_json = json.dumps(expected_calls)

    streamed_text = ""
    function_name_returned = False
    deltas = []
    for start in range(0, len(expected_json), delta_len):
        chunk = expected_json[start : start + delta_len]
        text_with_chunk = streamed_text + chunk

        delta, function_name_returned = (
            OpenAIServingChat.extract_tool_call_required_streaming(
                serving_stub,
                previous_text=streamed_text,
                current_text=text_with_chunk,
                delta_text=chunk,
                function_name_returned=function_name_returned,
            )
        )
        if delta:
            deltas.append(delta)

        streamed_text = text_with_chunk

    assert len(deltas) > 0

    # A delta carrying a function name opens a new tool call (closing the
    # previous one first, if anything was emitted yet); a delta without a
    # name only contributes further argument text.
    rebuilt = "["
    for delta in deltas:
        function_delta = delta.tool_calls[0].function
        if not function_delta.name:
            rebuilt += function_delta.arguments
            continue
        if rebuilt != "[":
            rebuilt += "},"
        rebuilt += "".join(
            (
                '{"name": "',
                function_delta.name,
                '", "parameters": ',
                function_delta.arguments,
            )
        )
    rebuilt += "}]"

    assert json.loads(rebuilt) == expected_calls
    assert json.dumps(json.loads(rebuilt)) == expected_json
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363


def test_streaming_output_valid_with_trailing_extra_data():
    """Stream a valid tool call followed by trailing non-JSON text.

    The serialized tool call with a newline and ``DONE`` appended is fed to
    ``OpenAIServingChat.extract_tool_call_required_streaming`` three
    characters at a time. The test only requires that at least one delta
    message is produced.
    """
    serving_stub = MagicMock()

    tool_calls = [{"name": "get_current_weather", "parameters": {"city": "Vienna"}}]
    streamed_payload = json.dumps(tool_calls) + "\nDONE"

    chunk_size = 3
    streamed_text = ""
    function_name_returned = False
    deltas = []
    for start in range(0, len(streamed_payload), chunk_size):
        chunk = streamed_payload[start : start + chunk_size]
        text_with_chunk = streamed_text + chunk

        delta, function_name_returned = (
            OpenAIServingChat.extract_tool_call_required_streaming(
                serving_stub,
                previous_text=streamed_text,
                current_text=text_with_chunk,
                delta_text=chunk,
                function_name_returned=function_name_returned,
            )
        )
        if delta:
            deltas.append(delta)

        streamed_text = text_with_chunk

    assert len(deltas) > 0