# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import ast
import json
from collections.abc import Sequence
from typing import Any

import regex as re
from transformers import PreTrainedTokenizerBase

import vllm.envs as envs
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import (
    ToolParser,
)

# Module-level logger, named after this module via ``__name__``.
logger = init_logger(__name__)


class _UnexpectedAstError(Exception):
    pass


class PythonicToolParser(ToolParser):
    """
    Tool call parser for models that produce tool calls in a pythonic style,
    such as Llama 3.2 and Llama 4 models.

    Used when --enable-auto-tool-choice and --tool-call-parser pythonic are
    both set.
    """

    # TODO(mdepinet): Possible future improvements:
    #   1. Support text + tools separated by either <|python_tag|> or \n\n
    #   2. Support tools outside of a list (or separated by a semicolon).
    #      This depends on item 1 for consistent streaming.
    # Neither of these are necessary for e.g. ToolACE, but both would help make
    # Llama3.2 models more reliable.

    # Pre-filter applied by extract_tool_calls before the text is handed to
    # ast.parse: a bracketed list of ``name(key=value, ...)`` calls.
    TOOL_CALL_REGEX = re.compile(
        r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
        re.DOTALL,
    )

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)

    # Rename for readability. This is NOT a tool id.
    @property
    def current_tool_index(self) -> int:
        """Alias for ``self.current_tool_id``."""
        return self.current_tool_id

    @current_tool_index.setter
    def current_tool_index(self, value: int) -> None:
        self.current_tool_id = value

    def extract_tool_calls(
        self, model_output: str, request: ChatCompletionRequest
    ) -> ExtractedToolCallInformation:
        """
        Extract the tool calls from a complete model response.

        The output is first matched against ``TOOL_CALL_REGEX`` (with a
        timeout), then parsed with ``ast.parse``. If the regex does not match,
        times out, or the parsed expression is not a list of calls that
        ``_handle_single_tool`` accepts, the output is returned unchanged as
        plain content with ``tools_called=False``.
        """
        is_tool_call_pattern = False
        try:
            is_tool_call_pattern = (
                self.TOOL_CALL_REGEX.match(
                    model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS
                )
                is not None
            )
        except TimeoutError:
            # A timeout is treated the same as "no match" below.
            logger.warning("Regex timeout occurred when matching tool call pattern.")
            logger.debug(
                "Regex timeout occurred when matching user input: %s", model_output
            )

        if not is_tool_call_pattern:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        try:
            module = ast.parse(model_output)
            parsed = getattr(module.body[0], "value", None)
            if isinstance(parsed, ast.List) and all(
                isinstance(e, ast.Call) for e in parsed.elts
            ):
                return ExtractedToolCallInformation(
                    tools_called=True,
                    tool_calls=[
                        _handle_single_tool(e)  # type: ignore
                        for e in parsed.elts
                    ],
                    content=None,
                )
            else:
                raise _UnexpectedAstError(
                    "Tool output must be a list of function calls"
                )
        except Exception:
            logger.exception("Error in extracting tool call from response.")
            # Treat as regular text
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """
        Compute the streaming delta for the text generated so far.

        Only ``current_text`` and ``delta_text`` are read; the remaining
        parameters are unused here. The method reads and updates
        ``self.current_tool_id``, ``self.streamed_args_for_tool`` and
        ``self.prev_tool_call_arr``, which are not assigned in this class
        (presumably initialised by ``ToolParser`` -- confirm).

        Returns:
            * ``DeltaMessage(content=delta_text)`` when ``current_text`` does
              not start with ``"["`` (plain text).
            * ``DeltaMessage(tool_calls=...)`` when new tool-call name or
              argument text is available.
            * ``DeltaMessage(content="")`` when nothing new is available, no
              text had to be autocompleted and at least one tool call has been
              fully streamed.
            * ``None`` when the partial text cannot be completed or parsed yet,
              when there is nothing to emit, or when any exception occurs.
        """
        if not current_text.startswith("["):
            return DeltaMessage(content=delta_text)

        try:
            valid_and_added_text = _make_valid_python(current_text)
            if valid_and_added_text is None:
                return None
            # added_text is whatever had to be appended to close the open
            # strings and brackets of the partial output.
            valid_text, added_text = valid_and_added_text

            module = ast.parse(valid_text)
            parsed = getattr(module.body[0], "value", None)
            if not isinstance(parsed, ast.List) or not all(
                isinstance(e, ast.Call) for e in parsed.elts
            ):
                raise _UnexpectedAstError(
                    "Tool output must be a list of function calls"
                )
            tool_calls = [
                _handle_single_tool(e)  # type: ignore
                for e in parsed.elts
            ]

            tool_deltas = []
            for index, new_call in enumerate(tool_calls):
                # Calls before current_tool_index were already streamed in full.
                if index < self.current_tool_index:
                    continue

                self.current_tool_index = index
                if len(self.streamed_args_for_tool) == index:
                    self.streamed_args_for_tool.append("")

                # Only the last call can be open, and only if its closing
                # parenthesis was autocompleted rather than generated.
                new_call_complete = (
                    index < len(tool_calls) - 1 or ")]" not in added_text
                )
                if new_call_complete:
                    self.current_tool_index += 1

                # Autocompleted closers (minus the final ")]") must not be
                # streamed, because the model has not produced them yet.
                withheld_suffix = added_text[:-2] if not new_call_complete else ""
                if not new_call_complete and added_text[-2] == ")":
                    # Function call is incomplete. Withhold the closing bracket.
                    withheld_suffix = withheld_suffix + "}"
                # Strings get single quotes in the model-produced string.
                # JSON requires double quotes.
                withheld_suffix = withheld_suffix.replace("'", '"')
                delta = _compute_tool_delta(
                    self.streamed_args_for_tool[index], new_call, index, withheld_suffix
                )

                if delta is not None:
                    tool_deltas.append(delta)
                    if (
                        delta.function is not None
                        and delta.function.arguments is not None
                    ):
                        self.streamed_args_for_tool[index] += delta.function.arguments

            # HACK: serving_chat.py inspects the internal state of tool parsers
            # when determining its final streaming delta, automatically
            # adding autocompleted JSON.
            # These two lines avoid that nonsense while ensuring finish_reason
            # is set to tool_calls when at least one tool is called.
            if tool_deltas and not self.prev_tool_call_arr:
                self.prev_tool_call_arr = [{"arguments": {}}]

            if tool_deltas:
                return DeltaMessage(tool_calls=tool_deltas)
            elif not added_text and self.current_tool_index > 0:
                # Return an empty DeltaMessage once the tool calls are all done
                # so that finish_reason gets set.
                return DeltaMessage(content="")
            else:
                return None
        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction error"
            )
            return None


def _get_parameter_value(val: ast.expr) -> Any:
    if isinstance(val, ast.Constant):
        return val.value
    elif isinstance(val, ast.Dict):
        if not all(isinstance(k, ast.Constant) for k in val.keys):
211
            raise _UnexpectedAstError("Dict tool call arguments must have literal keys")
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
        return {
            k.value: _get_parameter_value(v)  # type: ignore
            for k, v in zip(val.keys, val.values)
        }
    elif isinstance(val, ast.List):
        return [_get_parameter_value(v) for v in val.elts]
    else:
        raise _UnexpectedAstError("Tool call arguments must be literals")


def _handle_single_tool(call: ast.Call) -> ToolCall:
    """Build a ``ToolCall`` from one parsed ``name(key=value, ...)`` call node.

    Args:
        call: The call expression to convert.

    Returns:
        A ``ToolCall`` of type ``"function"`` whose arguments are the call's
        keyword arguments serialised as a JSON object string.

    Raises:
        _UnexpectedAstError: If the callee is not a plain name, or (via
            ``_get_parameter_value``) an argument value is not a supported
            literal.
    """
    if not isinstance(call.func, ast.Name):
        raise _UnexpectedAstError("Invalid tool call name")
    function_name = call.func.id
    # NOTE(review): only keyword arguments are read; positional arguments in
    # ``call.args`` are ignored -- confirm this is intended.
    arguments = {
        keyword.arg: _get_parameter_value(keyword.value)
        for keyword in call.keywords
    }
    return ToolCall(
        type="function",
        function=FunctionCall(
            name=function_name, arguments=json.dumps(arguments, ensure_ascii=False)
        ),
    )


def _make_valid_python(text: str) -> tuple[str, str] | None:
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
    bracket_stack = []
    for index, char in enumerate(text):
        if char in {"[", "(", "{"}:
            bracket_stack.append(char)
        elif char == "]":
            if not bracket_stack or bracket_stack.pop() != "[":
                raise _UnexpectedAstError("Mismatched square brackets")
        elif char == ")":
            if not bracket_stack or bracket_stack.pop() != "(":
                raise _UnexpectedAstError("Mismatched parentheses")
        elif char == "}":
            if not bracket_stack or bracket_stack.pop() != "{":
                raise _UnexpectedAstError("Mismatched curly braces")
        elif char in {"'", '"'}:
            if bracket_stack and bracket_stack[-1] == char:
                if index > 0 and text[index - 1] == "\\":
                    # Treat an escaped quote as a regular character
                    pass
                else:
                    bracket_stack.pop()
            elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
                # Double quote within a single quote string or vice versa.
                pass
            else:
                bracket_stack.append(char)

    text = text.rstrip()
    if text.endswith("=") or text.endswith(":"):
        # Since we have no type information for this property/parameter value,
        # we can't fill in a valid value.
        return None
    if bracket_stack and bracket_stack[-1] == "{":
270
        trailing_dict_text = text[: text.rfind("{")]
271
272
273
274
275
        num_keys = trailing_dict_text.count(":")
        num_values = trailing_dict_text.count(",")
        if num_keys <= num_values:
            return None  # Incomplete property name within parameter value
    if bracket_stack and bracket_stack[-1] == "(":
276
        trailing_params_text = text[: text.rfind("(")]
277
278
279
280
281
282
        num_full_param_names = trailing_params_text.count("=")
        num_full_param_values = trailing_params_text.count(",")
        if num_full_param_names <= num_full_param_values:
            return None  # Incomplete parameter name
    if text.endswith(","):
        text = text[:-1]
283
284
285
286
287
288
    if (
        bracket_stack
        and bracket_stack[-1] == "["
        and not text.endswith("[")
        and not text.endswith(")")
    ):
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
        return None  # Incomplete function name

    added_text = ""
    for char in reversed(bracket_stack):
        if char == "[":
            added_text += "]"
        elif char == "(":
            added_text += ")"
        elif char == "{":
            added_text += "}"
        elif char == "'":
            added_text += "'"
        elif char == '"':
            added_text += '"'

    return text + added_text, added_text


307
308
def _compute_tool_delta(
    previously_sent_args: str, new_call: ToolCall, index: int, withheld_suffix: str
309
) -> DeltaToolCall | None:
310
311
312
    new_call_args = new_call.function.arguments
    if withheld_suffix:
        assert new_call_args.endswith(withheld_suffix)
313
        new_call_args = new_call_args[: -len(withheld_suffix)]
314
    if not previously_sent_args:
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
        return DeltaToolCall(
            id=new_call.id,
            type="function",
            index=index,
            function=DeltaFunctionCall(
                name=new_call.function.name,
                arguments=new_call_args,
            ),
        )

    arg_diff = new_call_args[len(previously_sent_args) :]
    return (
        DeltaToolCall(
            id=None, index=index, function=DeltaFunctionCall(arguments=arg_diff)
        )
        if arg_diff
        else None
    )