# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
5
import json
from json import JSONDecodeError, JSONDecoder
6
from typing import Any
7
8

import partial_json_parser
9
10
11
12
13
from openai.types.responses import (
    FunctionTool,
    ToolChoiceFunction,
)
from openai.types.responses.tool import Tool
14
15
from partial_json_parser.core.options import Allow

16
from vllm.entrypoints.openai.chat_completion.protocol import (
17
18
19
20
    ChatCompletionNamedToolChoiceParam,
    ChatCompletionToolsParam,
)

21

22
23
24
25
26
27
28
29
30
31
32
33
34
def find_common_prefix(s1: str, s2: str) -> str:
    """
    Finds a common prefix that is shared between two strings, if there is one.
    Order of arguments is NOT important.

    This function is provided as a UTILITY for extracting information from JSON
    generated by partial_json_parser, to help in ensuring that the right tokens
    are returned in streaming, so that close-quotes, close-brackets and
    close-braces are not returned prematurely.

    e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
    '{"fruit": "ap'

    Returns the longest run of leading characters that is identical in both
    strings, or an empty string if the first characters differ or either
    string is empty.
    """
    # Count the matching leading characters and slice once, rather than
    # growing the result one character at a time.
    matched = 0
    for c1, c2 in zip(s1, s2):
        if c1 != c2:
            break
        matched += 1
    return s1[:matched]


def find_common_suffix(s1: str, s2: str) -> str:
    """
    Finds a common suffix shared between two strings, if there is one. Order of
    arguments is NOT important.
    Stops when the suffix ends OR it hits an alphanumeric character

    e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'

    Returns the longest run of trailing characters that is identical in both
    strings and contains no alphanumeric character, or an empty string if
    there is none.
    """
    # Walk both strings from the end in lockstep, counting characters that
    # match and are not alphanumeric.
    matched = 0
    for c1, c2 in zip(reversed(s1), reversed(s2)):
        if c1 != c2 or c1.isalnum():
            break
        matched += 1
    # Slice from an explicit start index: s1[-0:] would return the whole
    # string when nothing matched.
    return s1[len(s1) - matched:]


def extract_intermediate_diff(curr: str, old: str) -> str:
    """
    Given two strings, extract the difference in the middle between two strings
    that are known to have a common prefix and/or suffix.

    This function is provided as a UTILITY for extracting information from JSON
    generated by partial_json_parser, to help in ensuring that the right tokens
    are returned in streaming, so that close-quotes, close-brackets and
    close-braces are not returned prematurely. The order of arguments IS
    important - the new version of the partially-parsed JSON must be the first
    argument, and the second argument must be from the previous generation.

    What it returns, is tokens that should be streamed to the client.

    e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
        -> 'ple'

    """
    suffix = find_common_suffix(curr, old)

    # Drop the shared closing characters from the previous string before
    # measuring the shared prefix, so they are not counted as part of it.
    old = old.removesuffix(suffix)
    prefix = find_common_prefix(curr, old)

    # Strip the shared tail and then the shared head from the new string.
    # Both operations touch only the very end / very start of the string, so
    # an identical run of characters elsewhere in the middle is left alone.
    diff = curr.removesuffix(suffix)
    return diff.removeprefix(prefix)


96
def find_all_indices(string: str, substring: str) -> list[int]:
    """
    Find all (starting) indices of a substring in a given string. Useful for
    tool call extraction

    Matches may overlap: each search resumes one character after the start of
    the previous match. Returns an empty list when the substring does not
    occur.
    """
    indices: list[int] = []
    index = string.find(substring)
    while index != -1:
        indices.append(index)
        # Resume just past the start of this match so overlapping
        # occurrences are reported too.
        index = string.find(substring, index + 1)
    return indices
109
110
111


# partial_json_parser doesn't support extra data and
# JSONDecoder.raw_decode doesn't support partial JSON
def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
    """
    Parse a possibly incomplete JSON string, tolerating trailing extra data.

    The input is first handed to partial_json_parser.loads with the given
    flags. If that succeeds, the parsed value is returned together with
    len(input_str).

    If it fails with a JSONDecodeError whose message contains "Extra data",
    the input is decoded with JSONDecoder.raw_decode instead, and that call's
    (value, end index) result is returned as-is.

    Any other JSONDecodeError is re-raised unchanged.
    """
    try:
        return (partial_json_parser.loads(input_str, flags), len(input_str))
    except JSONDecodeError as e:
        if "Extra data" in e.msg:
            # Fall back to the stdlib decoder, which stops at the end of the
            # first JSON value and reports where it stopped.
            dec = JSONDecoder()
            return dec.raw_decode(input_str)
        raise


def is_complete_json(input_str: str) -> bool:
    """
    Report whether input_str parses as JSON with json.loads.

    Returns True when json.loads accepts the string and False when it raises
    JSONDecodeError.
    """
    try:
        json.loads(input_str)
    except JSONDecodeError:
        return False
    else:
        return True


def consume_space(i: int, s: str) -> int:
    """
    Advance past whitespace in s, starting at index i.

    Returns the index of the first character at or after i for which
    str.isspace() is false. If every remaining character is whitespace, the
    result is len(s); if i is already at or past the end, i is returned
    unchanged.
    """
    for pos in range(i, len(s)):
        if not s[pos].isspace():
            return pos
    # Ran off the end (or never started): the scan stops at the end of the
    # string, but never moves backwards from the starting index.
    return max(i, len(s))
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229


def _extract_tool_info(
    tool: Tool | ChatCompletionToolsParam,
) -> tuple[str, dict[str, Any] | None]:
    """
    Return the (name, parameters) pair of a tool definition.

    For a FunctionTool the values are read directly from the tool; for a
    ChatCompletionToolsParam they are read from its ``function`` attribute.

    Raises TypeError for any other tool type.
    """
    if isinstance(tool, FunctionTool):
        return tool.name, tool.parameters
    if isinstance(tool, ChatCompletionToolsParam):
        function = tool.function
        return function.name, function.parameters
    raise TypeError(f"Unsupported tool type: {type(tool)}")


def _get_tool_schema_from_tool(tool: Tool | ChatCompletionToolsParam) -> dict:
    """
    Build the schema fragment describing one call to the given tool.

    The fragment requires two properties: ``name``, restricted to this tool's
    name, and ``parameters``, which uses the tool's own parameter schema. When
    the tool has no (or empty) parameters, an object schema with no
    properties is used instead.
    """
    tool_name, tool_params = _extract_tool_info(tool)
    # `or` keeps the tool's own dict object whenever it is non-empty.
    parameters_schema = tool_params or {"type": "object", "properties": {}}
    name_schema = {"type": "string", "enum": [tool_name]}
    return {
        "properties": {
            "name": name_schema,
            "parameters": parameters_schema,
        },
        "required": ["name", "parameters"],
    }


def _get_tool_schema_defs(
    tools: list[Tool | ChatCompletionToolsParam],
) -> dict:
    """
    Collect the ``$defs`` entries of all tools into a single mapping.

    NOTE: the ``$defs`` key is popped from each tool's parameters dict, so
    the tools' parameter schemas are modified in place.

    Tools without parameters are skipped. The same definition name may appear
    in several tools only if the schemas are equal.

    Raises ValueError when one definition name maps to differing schemas.
    """
    merged: dict[str, dict[str, Any]] = {}
    for tool in tools:
        params = _extract_tool_info(tool)[1]
        if params is not None:
            for name, schema in params.pop("$defs", {}).items():
                conflicting = name in merged and merged[name] != schema
                if conflicting:
                    raise ValueError(
                        f"Tool definition '{name}' has multiple schemas, "
                        "which is not supported."
                    )
                merged[name] = schema
    return merged


def _get_json_schema_from_tools(
    tools: list[Tool | ChatCompletionToolsParam],
) -> dict:
    """
    Build a JSON schema for a non-empty array of tool calls.

    Each array item must be an object matching the schema of any one of the
    given tools. Definitions collected from the tools' ``$defs`` are attached
    at the top level of the returned schema when there are any.
    """
    # Build the per-tool schemas first, then collect the definitions; the
    # order of these two steps is kept as in the original implementation.
    per_tool_schemas = [_get_tool_schema_from_tool(tool) for tool in tools]
    schema = {
        "type": "array",
        "minItems": 1,
        "items": {"type": "object", "anyOf": per_tool_schemas},
    }
    shared_defs = _get_tool_schema_defs(tools)
    if shared_defs:
        schema["$defs"] = shared_defs
    return schema


def get_json_schema_from_tools(
    tool_choice: str | ToolChoiceFunction | ChatCompletionNamedToolChoiceParam,
    tools: list[FunctionTool | ChatCompletionToolsParam] | None,
) -> str | dict | None:
    """
    Select the JSON schema implied by ``tool_choice`` and ``tools``.

    Returns:
      - None when tool_choice is "none" or None, or when tools is None;
      - the named tool's parameters when tool_choice names a specific
        function (Responses or ChatCompletion style);
      - a schema for an array of tool calls when tool_choice is "required";
      - None for anything else (e.g. "auto").

    Raises ValueError when a named function is not among ``tools``.
    """
    # tool_choice: "none"
    if tool_choice in ("none", None) or tools is None:
        return None

    if not isinstance(tool_choice, str):
        # tool_choice: Forced Function (Responses)
        if isinstance(tool_choice, ToolChoiceFunction):
            wanted = tool_choice.name
            responses_tools = {
                tool.name: tool for tool in tools if isinstance(tool, FunctionTool)
            }
            if wanted not in responses_tools:
                raise ValueError(f"Tool '{wanted}' has not been passed in `tools`.")
            return responses_tools[wanted].parameters

        # tool_choice: Forced Function (ChatCompletion)
        if isinstance(tool_choice, ChatCompletionNamedToolChoiceParam):
            wanted = tool_choice.function.name
            chat_tools = {
                tool.function.name: tool
                for tool in tools
                if isinstance(tool, ChatCompletionToolsParam)
            }
            if wanted not in chat_tools:
                raise ValueError(f"Tool '{wanted}' has not been passed in `tools`.")
            return chat_tools[wanted].function.parameters

    # tool_choice: "required" builds the combined schema; "auto" yields None
    return _get_json_schema_from_tools(tools) if tool_choice == "required" else None