utils.py 14.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import ast
5
6
import json
from json import JSONDecodeError, JSONDecoder
7
from typing import Any, TypeAlias
8
9

import partial_json_parser
10
11
12
13
from openai.types.responses import (
    FunctionTool,
    ToolChoiceFunction,
)
14
from openai.types.responses.tool import Tool as ResponsesTool
15
16
from partial_json_parser.core.options import Allow

17
from vllm.entrypoints.openai.chat_completion.protocol import (
18
19
20
    ChatCompletionNamedToolChoiceParam,
    ChatCompletionToolsParam,
)
21
22
23
24
25
26
27
28
from vllm.entrypoints.openai.engine.protocol import (
    DeltaFunctionCall,
    DeltaToolCall,
    FunctionCall,
    ToolCall,
)
from vllm.logger import init_logger

29
30
# Tool definitions accepted by the helpers in this module: either Chat
# Completions tool params or OpenAI Responses API tools.
Tool: TypeAlias = ChatCompletionToolsParam | ResponsesTool

31
# Module-level logger used for the warnings/errors emitted by the parsing
# helpers in this module.
logger = init_logger(__name__)
32

33

34
35
36
37
38
39
40
41
42
43
44
45
46
def partial_tag_overlap(text: str, tag: str) -> int:
    """Return how many leading characters of *tag* appear at the end of *text*.

    Only proper prefixes are considered (at most ``len(tag) - 1`` characters),
    so a complete tag at the end of *text* does not count. The longest match
    wins, e.g. text ending in ``"<tool_"`` gives 6 for tag ``"<tool_call>"``.
    Returns 0 when no prefix of *tag* ends *text*.
    """
    longest = min(len(tag) - 1, len(text))
    return next(
        (size for size in range(longest, 0, -1) if text.endswith(tag[:size])),
        0,
    )


47
48
49
50
51
52
53
54
55
56
57
58
59
def find_common_prefix(s1: str, s2: str) -> str:
    """
    Finds a common prefix that is shared between two strings, if there is one.
    Order of arguments is NOT important.

    This function is provided as a UTILITY for extracting information from JSON
    generated by partial_json_parser, to help in ensuring that the right tokens
    are returned in streaming, so that close-quotes, close-brackets and
    close-braces are not returned prematurely.

    e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') ->
    '{"fruit": "ap'

    Returns "" when the strings differ at their first character or either
    string is empty.
    """
    # Count the matching leading characters and slice once; growing the
    # result with repeated ``+=`` can be quadratic in the prefix length.
    length = 0
    for c1, c2 in zip(s1, s2):
        if c1 != c2:
            break
        length += 1
    return s1[:length]


def find_common_suffix(s1: str, s2: str) -> str:
    """
    Finds a common suffix shared between two strings, if there is one. Order of
    arguments is NOT important.
    Stops when the suffix ends OR it hits an alphanumeric character, so only
    trailing punctuation such as close-quotes/brackets/braces is returned.

    e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}'
    """
    # Count matching non-alphanumeric trailing characters, then slice once
    # instead of prepending to a string on every iteration.
    length = 0
    for c1, c2 in zip(reversed(s1), reversed(s2)):
        if c1 != c2 or c1.isalnum():
            break
        length += 1
    # ``s1[-0:]`` would be the whole string, so slice from an explicit index.
    return s1[len(s1) - length :]


def extract_intermediate_diff(curr: str, old: str) -> str:
    """
    Given two strings, extract the difference in the middle between two strings
    that are known to have a common prefix and/or suffix.

    This function is provided as a UTILITY for extracting information from JSON
    generated by partial_json_parser, to help in ensuring that the right tokens
    are returned in streaming, so that close-quotes, close-brackets and
    close-braces are not returned prematurely. The order of arguments IS
    important - the new version of the partially-parsed JSON must be the first
    argument, and the second argument must be from the previous generation.

    What it returns, is tokens that should be streamed to the client.

    e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}')
        -> 'ple'

    """
    # Shared trailing punctuation (e.g. '"}') ends both strings; drop it so it
    # neither takes part in the prefix comparison nor appears in the diff.
    suffix = find_common_suffix(curr, old)
    old = old.removesuffix(suffix)
    prefix = find_common_prefix(curr, old)
    diff = curr.removesuffix(suffix)
    # ``diff`` is a prefix of ``curr``, so it starts with ``prefix`` whenever it
    # is at least that long; otherwise the prefix cannot occur in it at all.
    return diff.removeprefix(prefix)


121
# partial_json_parser doesn't support extra data and
122
# JSONDecoder.raw_decode doesn't support partial JSON
123
def partial_json_loads(input_str: str, flags: Allow) -> tuple[Any, int]:
    """Parse possibly-incomplete JSON while tolerating trailing extra data.

    partial_json_parser handles incomplete JSON but rejects extra data, and
    ``JSONDecoder.raw_decode`` handles extra data but not incomplete JSON, so
    the latter is used as a fallback only for "Extra data" errors.

    Args:
        input_str: The (possibly partial) JSON text.
        flags: partial_json_parser ``Allow`` flags controlling which
            incomplete constructs are accepted.

    Returns:
        ``(obj, end)`` where ``end`` is ``len(input_str)`` when the whole
        string was parsed, or the index just past the first complete JSON
        value when the input carries extra data after it.

    Raises:
        JSONDecodeError: For parse errors other than "Extra data", or when the
            ``raw_decode`` fallback cannot decode a value at the start.
    """
    try:
        return (partial_json_parser.loads(input_str, flags), len(input_str))
    except JSONDecodeError as e:
        if "Extra data" in e.msg:
            dec = JSONDecoder()
            return dec.raw_decode(input_str)
        raise


def is_complete_json(input_str: str) -> bool:
    """Report whether *input_str* is a single complete, well-formed JSON document.

    Truncated JSON, trailing extra data, and empty/whitespace-only strings
    all yield False.
    """
    try:
        json.loads(input_str)
    except JSONDecodeError:
        return False
    else:
        return True


def consume_space(i: int, s: str) -> int:
    """Skip whitespace in *s* starting at index *i*.

    Returns the index of the first non-whitespace character at or after *i*,
    or ``len(s)`` if only whitespace remains.
    """
    limit = len(s)
    while i < limit:
        if not s[i].isspace():
            break
        i += 1
    return i
145
146
147


def _extract_tool_info(
    tool: Tool,
) -> tuple[str, dict[str, Any] | None]:
    """Return ``(name, parameters)`` for a Responses or Chat Completions tool.

    ``parameters`` is the tool's own parameter-schema object (not a copy) and
    may be None.

    Raises:
        TypeError: If *tool* is neither a ``FunctionTool`` nor a
            ``ChatCompletionToolsParam``.
    """
    if isinstance(tool, FunctionTool):
        return tool.name, tool.parameters
    if isinstance(tool, ChatCompletionToolsParam):
        return tool.function.name, tool.function.parameters
    raise TypeError(f"Unsupported tool type: {type(tool)}")


158
159
160
161
162
163
164
165
166
167
168
169
170
171
def find_tool_properties(
    tools: list[Tool] | None,
    tool_name: str,
) -> dict[str, Any]:
    """Find a tool by name and return its properties dict, or {}.

    The first tool whose name equals *tool_name* wins. ``{}`` is returned
    when *tools* is empty/None, no tool matches, the tool has no parameter
    schema, or the schema's ``properties`` is missing or null.

    Raises:
        TypeError: Propagated from ``_extract_tool_info`` if an unsupported
            tool type is encountered before a match is found.
    """
    for tool in tools or ():
        name, params = _extract_tool_info(tool)
        if name == tool_name:
            # ``"properties": null`` is normalised to {} so the declared
            # ``dict`` return type always holds.
            return (params or {}).get("properties") or {}
    return {}


172
def _get_tool_schema_from_tool(tool: Tool) -> dict:
    """Build the ``anyOf`` branch that describes one call of *tool*.

    The branch requires a ``name`` constrained to the tool's own name and a
    ``parameters`` value following the tool's parameter schema (an empty
    object schema when the tool declares none). ``$defs`` is left out of the
    embedded parameters; ``_get_tool_schema_defs`` collects those separately
    so they can be placed at the root of the combined schema.
    """
    name, params = _extract_tool_info(tool)
    if params:
        # Shallow copy without ``$defs`` so the caller's tool definition is
        # never mutated.
        params = {key: value for key, value in params.items() if key != "$defs"}
    else:
        params = {"type": "object", "properties": {}}
    return {
        "properties": {
            "name": {"type": "string", "enum": [name]},
            "parameters": params,
        },
        "required": ["name", "parameters"],
    }


def _get_tool_schema_defs(
    tools: list[Tool],
) -> dict:
    """Merge the ``$defs`` of every tool's parameter schema into one mapping.

    The tools' parameter schemas are only read, never modified, so repeated
    calls with the same tools return the same result.

    Raises:
        ValueError: If two tools define the same ``$defs`` name with
            different schemas.
    """
    all_defs: dict[str, dict[str, Any]] = {}
    for tool in tools:
        _, params = _extract_tool_info(tool)
        if params is None:
            continue
        defs = params.get("$defs") or {}
        for def_name, def_schema in defs.items():
            if def_name in all_defs and all_defs[def_name] != def_schema:
                raise ValueError(
                    f"Tool definition '{def_name}' has multiple schemas, "
                    "which is not supported."
                )
            all_defs[def_name] = def_schema
    return all_defs


def _get_json_schema_from_tools(
    tools: list[Tool],
) -> dict:
    """Build the schema used for ``tool_choice="required"``.

    The result is a non-empty JSON array whose items each match the call
    schema of one of *tools*. The tools' ``$defs`` are merged and placed at
    the root under ``"$defs"`` when any tool defines them.

    Raises:
        ValueError: If tools define conflicting ``$defs`` entries.
    """
    json_schema = {
        "type": "array",
        "minItems": 1,
        "items": {
            "type": "object",
            "anyOf": [_get_tool_schema_from_tool(tool) for tool in tools],
        },
    }
    json_schema_defs = _get_tool_schema_defs(tools)
    if json_schema_defs:
        json_schema["$defs"] = json_schema_defs
    return json_schema


def get_json_schema_from_tools(
    tool_choice: str | ToolChoiceFunction | ChatCompletionNamedToolChoiceParam,
    tools: list[Tool] | None,
) -> str | dict | None:
    """Return the JSON schema that should constrain output for *tool_choice*.

    Returns:
        - ``None`` for ``"none"``, ``None``, ``"auto"`` (or any other
          string), or when *tools* is None.
        - For a named tool choice (Responses ``ToolChoiceFunction`` or Chat
          Completions ``ChatCompletionNamedToolChoiceParam``): that tool's
          ``parameters`` object, which may itself be None.
        - For ``"required"``: an array schema covering all *tools*.

    Raises:
        ValueError: If a named tool is not present in *tools*, or (for
            ``"required"``) if tools define conflicting ``$defs``.
    """
    # tool_choice: "none" / unset, or nothing to choose from
    if tool_choice is None or tool_choice == "none" or tools is None:
        return None

    # tool_choice: Forced Function (Responses)
    if isinstance(tool_choice, ToolChoiceFunction):
        tool_name = tool_choice.name
        tool_map = {tool.name: tool for tool in tools if isinstance(tool, FunctionTool)}
        if tool_name not in tool_map:
            raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.")
        return tool_map[tool_name].parameters

    # tool_choice: Forced Function (ChatCompletion)
    if isinstance(tool_choice, ChatCompletionNamedToolChoiceParam):
        tool_name = tool_choice.function.name
        tool_map = {
            tool.function.name: tool
            for tool in tools
            if isinstance(tool, ChatCompletionToolsParam)
        }
        if tool_name not in tool_map:
            raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.")
        return tool_map[tool_name].function.parameters

    # tool_choice: "required"
    if tool_choice == "required":
        return _get_json_schema_from_tools(tools)

    # tool_choice: "auto"
    return None
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452


# ---------------------------------------------------------------------------
# Shared utilities for pythonic-style tool call parsers
# (PythonicToolParser, Llama4PythonicToolParser, Olmo3PythonicToolParser)
# ---------------------------------------------------------------------------


class UnexpectedAstError(Exception):
    """Signals that (partially) parsed text does not have the shape of a
    pythonic tool call, e.g. non-literal arguments, a non-simple function
    name, or mismatched brackets."""


# JSON keywords that some models emit where Python expects None/True/False,
# mapped to the Python values they stand for.
_JSON_NAME_LITERALS = {"null": None, "true": True, "false": False}


def get_parameter_value(val: ast.expr) -> Any:
    """Extract a Python literal value from an AST expression node.

    Handles constants, dicts, lists, signed numeric constants such as ``-1``
    or ``+2.5`` (which ``ast`` represents as a unary operator applied to a
    constant), and JSON-style name literals (null, true, false) that some
    models produce instead of Python literals (None, True, False).

    Args:
        val: AST node for a single argument value.

    Returns:
        The equivalent Python value; dicts and lists are converted
        recursively.

    Raises:
        UnexpectedAstError: If the AST node is not a supported literal type,
            or a dict has non-literal keys.
    """
    if isinstance(val, ast.Constant):
        return val.value
    elif isinstance(val, ast.Dict):
        # ``{**other}`` entries carry a ``None`` key, which also fails here.
        if not all(isinstance(k, ast.Constant) for k in val.keys):
            logger.warning(
                "Dict argument keys are not all literals: %s",
                ast.dump(val),
            )
            raise UnexpectedAstError("Dict tool call arguments must have literal keys")
        return {
            k.value: get_parameter_value(v)  # type: ignore
            for k, v in zip(val.keys, val.values)
        }
    elif isinstance(val, ast.List):
        return [get_parameter_value(v) for v in val.elts]
    elif (
        isinstance(val, ast.UnaryOp)
        and isinstance(val.op, (ast.USub, ast.UAdd))
        and isinstance(val.operand, ast.Constant)
        and isinstance(val.operand.value, (int, float))
        and not isinstance(val.operand.value, bool)
    ):
        # ``-5`` parses as UnaryOp(USub, Constant(5)), not as Constant(-5).
        number = val.operand.value
        return -number if isinstance(val.op, ast.USub) else number
    elif isinstance(val, ast.Name) and val.id in _JSON_NAME_LITERALS:
        return _JSON_NAME_LITERALS[val.id]
    else:
        logger.warning(
            "Unsupported AST node type in tool call arguments: %s",
            ast.dump(val),
        )
        raise UnexpectedAstError("Tool call arguments must be literals")


def handle_single_tool(call: ast.Call) -> ToolCall:
    """Convert a single AST function call node into a ToolCall object.

    Each keyword argument's value is extracted with ``get_parameter_value``.
    The resulting mapping is serialized with ``json.dumps(...,
    ensure_ascii=False)`` as the call's arguments.

    Raises:
        UnexpectedAstError: If the call node does not have a simple
            function name (e.g. it's an attribute access or subscript),
            uses ``**`` keyword unpacking, or an argument value is not a
            supported literal.
    """
    if not isinstance(call.func, ast.Name):
        logger.warning(
            "Tool call has non-simple function name: %s",
            ast.dump(call.func),
        )
        raise UnexpectedAstError("Invalid tool call name")
    function_name = call.func.id
    # NOTE(review): positional arguments (``call.args``) are not inspected and
    # are silently dropped; only keyword arguments become tool arguments.
    arguments: dict[str, Any] = {}
    for keyword in call.keywords:
        if keyword.arg is None:
            # ``**mapping`` has no parameter name. Without this check the value
            # was stored under a ``None`` key, which json.dumps emits as "null".
            logger.warning(
                "Tool call uses ** keyword unpacking: %s",
                ast.dump(keyword.value),
            )
            raise UnexpectedAstError("Tool call arguments must be named keywords")
        arguments[keyword.arg] = get_parameter_value(keyword.value)
    return ToolCall(
        type="function",
        function=FunctionCall(
            name=function_name,
            arguments=json.dumps(arguments, ensure_ascii=False),
        ),
    )


def make_valid_python(text: str) -> tuple[str, str] | None:
    """Attempt to close all open brackets/quotes to make partial Python valid.

    Used during streaming to parse incomplete tool call expressions by
    appending the necessary closing characters.

    Returns:
        A tuple of (completed_text, added_suffix) if the text can be
        made valid, or None if the text is too incomplete to complete
        meaningfully (e.g. mid-parameter-name or mid-dict-key).
        ``completed_text`` is the right-stripped input (minus one trailing
        comma, if present) followed by ``added_suffix``.

    Raises:
        UnexpectedAstError: If mismatched brackets or parentheses
            are detected.
    """
    # Currently open brackets and quote characters, innermost last.
    bracket_stack: list[str] = []
    for index, char in enumerate(text):
        # NOTE(review): bracket characters are pushed/popped even while a quote
        # is open, so brackets inside string values are treated as structure
        # (a ")" inside a string raises "Mismatched parentheses"). Confirm
        # whether this is intended.
        if char in {"[", "(", "{"}:
            bracket_stack.append(char)
        elif char == "]":
            if not bracket_stack or bracket_stack.pop() != "[":
                raise UnexpectedAstError("Mismatched square brackets")
        elif char == ")":
            if not bracket_stack or bracket_stack.pop() != "(":
                raise UnexpectedAstError("Mismatched parentheses")
        elif char == "}":
            if not bracket_stack or bracket_stack.pop() != "{":
                raise UnexpectedAstError("Mismatched curly braces")
        elif char in {"'", '"'}:
            if bracket_stack and bracket_stack[-1] == char:
                # Same quote as the open one: it closes the string unless the
                # previous character is a backslash.
                # NOTE(review): only one preceding character is checked, so an
                # escaped backslash before a quote is also treated as escaping
                # the quote.
                if index > 0 and text[index - 1] == "\\":
                    pass
                else:
                    bracket_stack.pop()
            elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
                # The other quote type inside an open string is plain text.
                pass
            else:
                # Opening quote.
                bracket_stack.append(char)

    text = text.rstrip()
    # A dangling "=" or ":" has no value yet, so nothing valid can be appended.
    if text.endswith("=") or text.endswith(":"):
        return None
    # Innermost open bracket is a dict: compare ":" vs "," counts to decide
    # whether a key is only partially written.
    # NOTE(review): the counts are taken from the text *before* the last "{",
    # not after it. Confirm this is the intended heuristic.
    if bracket_stack and bracket_stack[-1] == "{":
        trailing_dict_text = text[: text.rfind("{")]
        num_keys = trailing_dict_text.count(":")
        num_values = trailing_dict_text.count(",")
        if num_keys <= num_values:
            return None
    # Innermost open bracket is a call: same heuristic using "=" vs ",".
    # NOTE(review): likewise counted on the text before the last "(".
    if bracket_stack and bracket_stack[-1] == "(":
        trailing_params_text = text[: text.rfind("(")]
        num_full_param_names = trailing_params_text.count("=")
        num_full_param_values = trailing_params_text.count(",")
        if num_full_param_names <= num_full_param_values:
            return None
    # Drop one trailing comma so the closing characters can follow directly.
    if text.endswith(","):
        text = text[:-1]
    # Innermost open bracket is a list: only complete directly after "[" or
    # after a closed call ")"; anything else is treated as too incomplete.
    if (
        bracket_stack
        and bracket_stack[-1] == "["
        and not text.endswith("[")
        and not text.endswith(")")
    ):
        return None

    # Close everything still open, innermost first.
    _CLOSING = {"[": "]", "(": ")", "{": "}", "'": "'", '"': '"'}
    added_text = ""
    for char in reversed(bracket_stack):
        added_text += _CLOSING[char]

    return text + added_text, added_text


def compute_tool_delta(
    previously_sent_args: str,
    new_call: ToolCall,
    index: int,
    withheld_suffix: str,
) -> DeltaToolCall | None:
    """Diff a tool call's current arguments against what was already streamed.

    ``withheld_suffix`` is stripped from the end of the new arguments before
    diffing. It is presumably the closing text that was only appended to make
    partial output parseable; confirm at the caller.

    If nothing has been sent yet, the first delta carries the call id, type,
    name and all (suffix-stripped) arguments. Afterwards, only the characters
    past ``len(previously_sent_args)`` are emitted.

    Returns:
        The DeltaToolCall to stream, or None if there is nothing new.

    Raises:
        ValueError: If the new arguments do not end with ``withheld_suffix``.
    """
    current_args = new_call.function.arguments
    if withheld_suffix:
        if not current_args.endswith(withheld_suffix):
            msg = (
                f"Tool call arguments '{current_args}' do not end with "
                f"expected withheld suffix '{withheld_suffix}'"
            )
            logger.error(msg)
            raise ValueError(msg)
        current_args = current_args.removesuffix(withheld_suffix)

    if previously_sent_args:
        unsent_args = current_args[len(previously_sent_args) :]
        if not unsent_args:
            return None
        return DeltaToolCall(
            id=None,
            index=index,
            function=DeltaFunctionCall(arguments=unsent_args),
        )

    # Nothing streamed yet for this call: announce id, type and name too.
    return DeltaToolCall(
        id=new_call.id,
        type="function",
        index=index,
        function=DeltaFunctionCall(
            name=new_call.function.name,
            arguments=current_args,
        ),
    )