# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
from copy import deepcopy
from unittest.mock import MagicMock

import pytest
8
import regex as re
9
10
from pydantic import TypeAdapter

11
12
13
from vllm.entrypoints.openai.protocol import (
    ChatCompletionToolsParam,
)
14
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
15
from vllm.entrypoints.openai.tool_parsers.utils import get_json_schema_from_tools
16

17
18
# Module-level pytest marker: applies the `cpu_test` mark to every test
# collected from this file.
pytestmark = pytest.mark.cpu_test

19
20
21
22
23
24
25
26
27
28
# Two function-tool definitions (as plain dicts) shared by the schema tests:
# `get_current_weather` requires a string `city`; `get_forecast` requires a
# string `city` and an integer `days`. Both forbid additional properties.
EXAMPLE_TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to find the weather for"
                        ", e.g. 'San Francisco'",
                    },
                },
                "required": ["city"],
                "additionalProperties": False,
            },
        },
        "strict": True,
    },
    {
        "type": "function",
        "function": {
            "name": "get_forecast",
            "description": "Get the weather forecast for a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to get the forecast for, e.g. "
                        "'New York'",
                    },
                    "days": {
                        "type": "integer",
                        "description": "Number of days to get the forecast for (1-7)",
                    },
                },
                "required": ["city", "days"],
                "additionalProperties": False,
            },
        },
        "strict": True,
    },
]


67
68
69
def _compile_and_check(
    tools: list[ChatCompletionToolsParam], sample_output, should_match: bool
):
    """Assert whether ``sample_output`` is accepted by the required-tools schema.

    Builds the JSON schema for ``tools`` with ``tool_choice="required"``,
    converts it to a regular expression, and checks that a full match of the
    JSON-serialized ``sample_output`` against that expression agrees with
    ``should_match``.

    Args:
        tools: Tool definitions the schema is derived from.
        sample_output: JSON-serializable candidate output to test.
        should_match: Whether ``sample_output`` is expected to match.

    Raises:
        AssertionError: If the schema is not a dict, or if the match result
            differs from ``should_match``.
    """
    schema = get_json_schema_from_tools(tools=tools, tool_choice="required")
    assert isinstance(schema, dict)

    # use build_regex_from_schema used in JSONLogitsProcessor to create Guide
    from outlines_core.json_schema import build_regex_from_schema

    regex = build_regex_from_schema(json.dumps(schema))
    compiled = re.compile(regex)
    matches = compiled.fullmatch(json.dumps(sample_output)) is not None

    assert matches == should_match


# (tool-call list, should_match) pairs. Every entry here carries
# should_match=True; the list is reused as the positive cases of the
# parametrized schema test and as input for the streaming test.
VALID_TOOL_OUTPUTS = [
    ([{"name": "get_current_weather", "parameters": {"city": "Vienna"}}], True),
    (
        [
            {"name": "get_current_weather", "parameters": {"city": "Vienna"}},
            {"name": "get_current_weather", "parameters": {"city": "Berlin"}},
        ],
        True,
    ),
    ([{"name": "get_forecast", "parameters": {"city": "Vienna", "days": 7}}], True),
    (
        [
            {"name": "get_forecast", "parameters": {"city": "Vienna", "days": 7}},
            {"name": "get_current_weather", "parameters": {"city": "Vienna"}},
        ],
        True,
    ),
    (
        [
            {"name": "get_forecast", "parameters": {"city": "Vienna", "days": 7}},
            {"name": "get_current_weather", "parameters": {"city": "Vienna"}},
            {"name": "get_forecast", "parameters": {"city": "Berlin", "days": 7}},
            {"name": "get_current_weather", "parameters": {"city": "Berlin"}},
        ],
        True,
    ),
]

# Only the tool-call lists, with the expected-match flag dropped.
VALID_TOOLS = [tool_calls for tool_calls, _ in VALID_TOOL_OUTPUTS]


@pytest.mark.parametrize(
    "sample_output, should_match",
    VALID_TOOL_OUTPUTS
    + [
        (None, False),
        ([], False),  # empty list cannot be generated
        ({}, False),  # empty object cannot be generated
        ([{}], False),  # list with empty object cannot be generated
        (
            [
                {  # tool call missing the "parameters" key cannot be generated
                    "name": "get_current_weather"
                }
            ],
            False,
        ),
        (
            [
                {  # tool call with empty parameters (required "city" missing)
                    "name": "get_current_weather",
                    "parameters": {},
                }
            ],
            False,
        ),
        (
            [
                {  # tool call with null parameters cannot be generated
                    "name": "get_current_weather",
                    "parameters": None,
                }
            ],
            False,
        ),
        (
            {  # tool call without lists cannot be generated
                "name": "get_current_weather",
                "parameters": {"city": "Vienna"},
            },
            False,
        ),
        (
            [
                {  # tool call with extra parameters cannot be generated
                    "name": "get_current_weather",
                    "parameters": {"city": "Vienna", "extra": "value"},
                }
            ],
            False,
        ),
        (
            [
                {  # tool call where parameters are first cannot be generated
                    "parameters": {"city": "Vienna"},
                    "name": "get_current_weather",
                }
            ],
            False,
        ),
        (
            [
                {  # tool call without all required parameters cannot be generated
                    "name": "get_forecast",
                    "parameters": {"city": "Vienna"},
                }
            ],
            False,
        ),
        (  # tool call with incorrect name/parameters cannot be generated
            [{"name": "get_weather", "parameters": {"city": "Vienna", "days": 7}}],
            False,
        ),
        (  # tool call with both valid and empty function cannot be generated
            [{"name": "get_current_weather", "parameters": {"city": "Vienna"}}, {}],
            False,
        ),
    ],
)
def test_structured_outputs_json(sample_output, should_match):
    """Check each sample output against the schema built from EXAMPLE_TOOLS.

    The valid outputs from ``VALID_TOOL_OUTPUTS`` must match; every malformed
    variant listed in the parametrization must not.
    """
    _compile_and_check(
        tools=TypeAdapter(list[ChatCompletionToolsParam]).validate_python(
            EXAMPLE_TOOLS
        ),
        sample_output=sample_output,
        should_match=should_match,
    )
202
203


204
def update_parameters_none(tool: ChatCompletionToolsParam) -> ChatCompletionToolsParam:
    """Set the tool's function parameters to ``None`` in place and return it."""
    tool.function.parameters = None
    return tool


def update_parameters_empty_dict(
    tool: ChatCompletionToolsParam,
) -> ChatCompletionToolsParam:
    """Set the tool's function parameters to ``{}`` in place and return it."""
    tool.function.parameters = {}
    return tool


@pytest.mark.parametrize(
    "sample_output, should_match",
    [
        (None, False),
        ([], False),  # empty list cannot be generated
        ({}, False),  # empty object cannot be generated
        ([{}], False),  # list with empty object cannot be generated
        (
            [
                {  # tool call missing the "parameters" key cannot be generated
                    "name": "get_current_weather"
                }
            ],
            False,
        ),
        (
            [
                {  # tool call with null parameters cannot be generated
                    "name": "get_current_weather",
                    "parameters": None,
                }
            ],
            False,
        ),
        (
            [
                {  # function with extra parameters cannot be generated
                    "name": "get_current_weather",
                    "parameters": {"extra": "value"},
                }
            ],
            False,
        ),
        (
            [
                {  # only function with empty parameters object is valid
                    "name": "get_current_weather",
                    "parameters": {},
                }
            ],
            True,
        ),
    ],
)
@pytest.mark.parametrize(
    "update_parameters", [update_parameters_none, update_parameters_empty_dict]
)
def test_structured_outputs_json_without_parameters(
    sample_output, should_match, update_parameters
):
    """Check schema matching for a tool whose parameters are ``None`` or ``{}``.

    Takes the first example tool, clears its parameters with
    ``update_parameters``, and verifies that only a call carrying an empty
    ``parameters`` object matches the resulting schema.
    """
    updated_tools = [deepcopy(EXAMPLE_TOOLS[0])]
    tools = TypeAdapter(list[ChatCompletionToolsParam]).validate_python(updated_tools)
    tools = list(map(update_parameters, tools))

    # Sanity check: the updater really cleared the parameters on every tool.
    assert all(
        tool.function.parameters is None or tool.function.parameters == {}
        for tool in tools
    )
    _compile_and_check(
        tools=tools, sample_output=sample_output, should_match=should_match
    )
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294


@pytest.mark.parametrize("output", VALID_TOOLS)
@pytest.mark.parametrize("empty_params", [False, True])
@pytest.mark.parametrize("delta_len", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
def test_streaming_output_valid(output, empty_params, delta_len):
    """Stream a valid tool-call JSON in fixed-size chunks and reassemble it.

    The serialized tool-call list is fed to
    ``OpenAIServingChat.extract_tool_call_required_streaming`` in slices of
    ``delta_len`` characters. The emitted delta messages are stitched back
    into a JSON string, which must round-trip to the original output.
    """
    # The method is called unbound, so a mock stands in for the instance.
    serving_chat = MagicMock()

    output = deepcopy(output)
    if empty_params:
        output = [{"name": o["name"], "parameters": {}} for o in output]
    output_json = json.dumps(output)

    previous_text = ""
    function_name_returned = False
    messages = []
    for i in range(0, len(output_json), delta_len):
        delta_text = output_json[i : i + delta_len]
        current_text = previous_text + delta_text

        delta_message, function_name_returned = (
            OpenAIServingChat.extract_tool_call_required_streaming(
                serving_chat,
                previous_text=previous_text,
                current_text=current_text,
                delta_text=delta_text,
                function_name_returned=function_name_returned,
            )
        )

        if delta_message:
            messages.append(delta_message)

        previous_text = current_text

    assert len(messages) > 0
    combined_messages = "["
    for message in messages:
        function = message.tool_calls[0].function
        if function.name:
            # A delta carrying a name starts a new tool call; close the
            # previous call's object first, unless this is the first one.
            if len(combined_messages) > 1:
                combined_messages += "},"

            combined_messages += (
                '{"name": "'
                + function.name
                + '", "parameters": '
                + function.arguments
            )
        else:
            # Nameless deltas only extend the current call's arguments.
            combined_messages += function.arguments
    combined_messages += "}]"
    assert json.loads(combined_messages) == output
    assert json.dumps(json.loads(combined_messages)) == output_json