"vllm/vscode:/vscode.git/clone" did not exist on "7d94577138e3d4c7bcfd781337ee1e5a2befa685"
step3_tool_parser.py 11.9 KB
Newer Older
Song's avatar
Song committed
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib
import json
from collections.abc import Sequence
7
from typing import Any
Song's avatar
Song committed
8
9
10

import regex as re

11
from vllm.entrypoints.openai.chat_completion.protocol import (
12
    ChatCompletionRequest,
13
14
)
from vllm.entrypoints.openai.engine.protocol import (
15
16
17
18
19
20
21
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
Song's avatar
Song committed
22
from vllm.logger import init_logger
23
from vllm.tokenizers import TokenizerLike
24
25
26
from vllm.tool_parsers.abstract_tool_parser import (
    ToolParser,
)
Song's avatar
Song committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from vllm.utils import random_uuid

# Module-level logger for this file, created with vLLM's `init_logger` helper.
logger = init_logger(__name__)


class Step3ToolParser(ToolParser):
    """
    Tool parser for the Step-3 special-token / ``steptml`` tool-call format.

    A tool block accepted by this parser looks like::

        <|tool_calls_begin|>
        <|tool_call_begin|>function<|tool_sep|><steptml:invoke name="NAME">
        <steptml:parameter name="ARG">VALUE</steptml:parameter>
        </steptml:invoke><|tool_call_end|>
        <|tool_calls_end|>

    Parameter values are plain text. They are converted to the JSON types
    declared in the request's tool schema before being serialized as the
    call's arguments.

    Streaming uses a cursor (``self.position``) into ``current_text`` plus
    explicit state flags. The function name is emitted as soon as the
    ``<steptml:invoke ...>`` header is complete. The arguments are emitted as a
    single JSON string once the call's ``<|tool_call_end|>`` has been seen.
    """

    TOOL_CALLS_BEGIN = "<|tool_calls_begin|>"
    TOOL_CALLS_END = "<|tool_calls_end|>"
    TOOL_CALL_BEGIN = "<|tool_call_begin|>"
    TOOL_CALL_END = "<|tool_call_end|>"
    TOOL_SEP = "<|tool_sep|>"
    SPECIAL_TOKENS = [TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END]

    # Matches the invoke header and captures the function name.
    _INVOKE_RE = re.compile(r'<steptml:invoke name="([^"]+)">')
    # Captures (name, value) of every *closed* parameter. The value may contain
    # "<" and newlines, but it may not run across another opening parameter
    # tag, so a malformed (unclosed) parameter cannot swallow the next one.
    _PARAM_RE = re.compile(
        r'<steptml:parameter name="([^"]+)">'
        r"((?:(?!<steptml:parameter ).)*?)"
        r"</steptml:parameter>",
        re.DOTALL,
    )

    def __init__(self, tokenizer: TokenizerLike):
        super().__init__(tokenizer)
        # Index into ``current_text`` of the first character the streaming
        # parser has not consumed yet.
        self.position = 0
        # Explicit state flags for the streaming state machine.
        self.tool_block_started = False
        self.tool_block_finished = False

    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Keep special tokens in the output when tools may be called.

        The tool-call markers are listed in ``SPECIAL_TOKENS``. If they were
        skipped during detokenization, the marker strings this parser searches
        for would never appear in the text.
        """
        request = super().adjust_request(request)
        if request.tools and request.tool_choice != "none":
            request.skip_special_tokens = False
        return request

    @staticmethod
    def _parse_steptml_invoke(
        action_text: str,
    ) -> tuple[str | None, dict[str, str] | None]:
        """Parse a ``<steptml:invoke>`` body into a name and raw string params.

        Returns ``(None, None)`` when no complete invoke header is present.
        Parameter values are whitespace-stripped strings. Only parameters whose
        closing tag has been seen are returned, so a partially streamed body
        yields a subset of the final parameters.
        """
        func_name_match = Step3ToolParser._INVOKE_RE.search(action_text)
        if not func_name_match:
            return None, None
        func_name = func_name_match.group(1)

        params: dict[str, str] = {}
        for name, value in Step3ToolParser._PARAM_RE.findall(action_text):
            params[name] = value.strip()
        return func_name, params

    @staticmethod
    def _cast_value(value: str, declared_type: Any) -> Any:
        """Convert ``value`` according to a JSON-schema ``type`` entry.

        ``declared_type`` may be a single type name or a list of names (for
        example ``["integer", "null"]``). If ``"null"`` is allowed and the value
        is ``null`` (case-insensitive), ``None`` is returned. Otherwise the
        types are tried in declaration order and the first successful
        conversion wins. If none applies, ``value`` is returned unchanged.
        """
        types = declared_type if isinstance(declared_type, list) else [declared_type]
        lowered = value.lower()
        if "null" in types and lowered == "null":
            return None
        for typ in types:
            if typ == "string":
                return value.strip()
            elif typ == "integer":
                with contextlib.suppress(ValueError):
                    return int(value)
            elif typ == "number":
                with contextlib.suppress(ValueError):
                    return float(value)
            elif typ == "boolean" and lowered in ("true", "false"):
                return lowered == "true"
        return value

    def _cast_arguments(
        self,
        func_name: str,
        params: dict[str, Any],
        request: ChatCompletionRequest,
    ) -> dict[str, Any]:
        """Return a copy of ``params`` with string values cast to schema types.

        Looks up ``func_name`` in ``request.tools`` and converts each string
        value whose property declares a ``type`` (see ``_cast_value``).
        Non-string values, unknown properties, and unknown tools are passed
        through unchanged. ``params`` itself is never modified.
        """
        tool = next(
            (t for t in request.tools or [] if t.function.name == func_name),
            None,
        )
        if tool is None:
            return dict(params)

        # ``properties`` may be missing or explicitly null in the schema.
        properties = (tool.function.parameters or {}).get("properties") or {}
        casted: dict[str, Any] = {}
        for key, value in params.items():
            prop = properties.get(key)
            if isinstance(value, str) and isinstance(prop, dict):
                casted[key] = self._cast_value(value, prop.get("type"))
            else:
                casted[key] = value
        return casted

    def _in_active_tool_call(self) -> bool:
        """Return True while a started tool call has not seen its end marker."""
        return self.current_tool_id != -1 and not self.prev_tool_call_arr[
            self.current_tool_id
        ].get("finished", False)

    def _begin_tool_call(self) -> None:
        """Allocate bookkeeping for a tool call whose begin marker was consumed."""
        self.current_tool_id += 1
        self.current_tool_name_sent = False
        while len(self.prev_tool_call_arr) <= self.current_tool_id:
            self.prev_tool_call_arr.append({})
        while len(self.streamed_args_for_tool) <= self.current_tool_id:
            self.streamed_args_for_tool.append("")
        self.prev_tool_call_arr[self.current_tool_id]["finished"] = False

    def _discard_current_tool_call(self) -> None:
        """Forget the current (malformed) tool call so indices stay contiguous."""
        index = self.current_tool_id
        if len(self.prev_tool_call_arr) == index + 1:
            self.prev_tool_call_arr.pop()
        if len(self.streamed_args_for_tool) == index + 1:
            self.streamed_args_for_tool.pop()
        self.current_tool_id = index - 1

    def _finish_tool_call(
        self,
        function_name: str,
        params: dict[str, Any],
        request: ChatCompletionRequest,
    ) -> str:
        """Mark the current tool call finished and return its arguments as JSON."""
        final_args = self._cast_arguments(function_name, params, request)
        arguments_json = json.dumps(final_args, ensure_ascii=False)
        tool_state = self.prev_tool_call_arr[self.current_tool_id]
        tool_state["finished"] = True
        # Keep the base-class bookkeeping consistent with what was streamed.
        # NOTE(review): presumably read by the serving layer when reconciling
        # streamed arguments; confirm.
        tool_state["arguments"] = final_args
        self.streamed_args_for_tool[self.current_tool_id] = arguments_json
        return arguments_json

    def _find_next_marker(self, text: str) -> int:
        """Return the index of the earliest call-begin or block-end marker, or -1."""
        positions = [
            pos
            for pos in (
                text.find(self.TOOL_CALL_BEGIN),
                text.find(self.TOOL_CALLS_END),
            )
            if pos != -1
        ]
        return min(positions, default=-1)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Incrementally parse ``current_text`` from ``self.position`` onward.

        Only ``current_text`` and ``request`` are used; the other arguments are
        part of the ``ToolParser`` interface. Each call returns at most one
        delta: content, a tool-call name, tool-call arguments, or a name and
        arguments together. It returns ``None`` when more text is needed.
        Unconsumed text stays after ``self.position`` and is handled on the
        next call.
        """
        while self.position < len(current_text):
            unprocessed_text = current_text[self.position :]

            # STATE: after the tool block, all remaining text is content.
            if self.tool_block_finished:
                self.position = len(current_text)
                return DeltaMessage(content=unprocessed_text)

            # STATE: before the tool block has started.
            if not self.tool_block_started:
                if unprocessed_text.startswith(self.TOOL_CALLS_BEGIN):
                    self.position += len(self.TOOL_CALLS_BEGIN)
                    self.tool_block_started = True
                    continue
                start_pos = unprocessed_text.find(self.TOOL_CALLS_BEGIN)
                if start_pos != -1:
                    # Emit the content before the marker; the marker itself
                    # is consumed on the next call.
                    self.position += start_pos
                    return DeltaMessage(content=unprocessed_text[:start_pos])
                if self.TOOL_CALLS_BEGIN.startswith(unprocessed_text.strip()):
                    # Whitespace only, or possibly a partial begin marker: wait.
                    return None
                self.position = len(current_text)
                return DeltaMessage(content=unprocessed_text)

            # STATE: inside the tool block; whitespace between markers is skipped.
            stripped = unprocessed_text.lstrip()
            self.position += len(unprocessed_text) - len(stripped)
            unprocessed_text = stripped

            if unprocessed_text.startswith(self.TOOL_CALLS_END):
                self.position += len(self.TOOL_CALLS_END)
                self.tool_block_finished = True
                self.current_tool_id = -1
                continue

            # STATE: between tool calls (or before the first one).
            if not self._in_active_tool_call():
                if unprocessed_text.startswith(self.TOOL_CALL_BEGIN):
                    self.position += len(self.TOOL_CALL_BEGIN)
                    self._begin_tool_call()
                    continue
                marker_pos = self._find_next_marker(unprocessed_text)
                if marker_pos > 0:
                    # Unexpected text between calls: drop it instead of
                    # stalling on it forever.
                    logger.debug(
                        "Dropping unexpected text inside Step3 tool block: %r",
                        unprocessed_text[:marker_pos],
                    )
                    self.position += marker_pos
                    continue
                # A partial marker, or nothing recognizable yet: wait.
                return None

            # STATE: parsing an active tool call. The body is re-parsed from
            # its start on every call until its end marker arrives.
            end_tool_pos = unprocessed_text.find(self.TOOL_CALL_END)
            call_complete = end_tool_pos != -1
            tool_body = (
                unprocessed_text[:end_tool_pos] if call_complete else unprocessed_text
            )
            function_name, arguments = self._parse_steptml_invoke(tool_body)

            if not function_name:
                if not call_complete:
                    return None  # The invoke header has not fully arrived yet.
                # A closed call without a parsable invoke header can never
                # become valid; skip it so the parser does not stall on it.
                logger.warning(
                    "Skipping Step3 tool call without a valid invoke header: %r",
                    tool_body,
                )
                self.position += end_tool_pos + len(self.TOOL_CALL_END)
                self._discard_current_tool_call()
                continue

            params = arguments or {}
            self.prev_tool_call_arr[self.current_tool_id].update(
                {"name": function_name, "parameters": params}
            )

            arguments_json: str | None = None
            if call_complete:
                self.position += end_tool_pos + len(self.TOOL_CALL_END)
                # Always produced on completion, so zero-argument calls still
                # stream "{}" (matching the non-streaming path).
                arguments_json = self._finish_tool_call(
                    function_name, params, request
                )

            if self.current_tool_name_sent:
                if arguments_json is None:
                    return None  # Name already sent; arguments not complete.
                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_id,
                            function=DeltaFunctionCall(arguments=arguments_json),
                        )
                    ]
                )

            # First emission for this call: send the name, plus the arguments
            # when the whole call arrived at once (so they cannot lag behind).
            self.current_tool_name_sent = True
            function = (
                DeltaFunctionCall(name=function_name)
                if arguments_json is None
                else DeltaFunctionCall(name=function_name, arguments=arguments_json)
            )
            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_id,
                        type="function",
                        id=f"chatcmpl-tool-{random_uuid()}",
                        function=function,
                    )
                ]
            )

        return None  # Everything received so far has been consumed.

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse all tool calls from a complete model output.

        Only the first ``<|tool_calls_begin|> ... <|tool_calls_end|>`` block is
        parsed. Calls that lack an end marker, a ``<|tool_sep|>``, the
        ``function`` type, or a parsable invoke header are skipped. When at
        least one call is parsed, the text outside the block (stripped, or
        ``None`` if empty) becomes the content. Otherwise the raw
        ``model_output`` is returned as content.
        """
        if self.TOOL_CALLS_BEGIN not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        pre_text, rest = model_output.split(self.TOOL_CALLS_BEGIN, 1)
        if self.TOOL_CALLS_END not in rest:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        tool_block, post_text = rest.split(self.TOOL_CALLS_END, 1)
        content = (pre_text + post_text).strip()

        tool_calls: list[ToolCall] = []
        for part in tool_block.split(self.TOOL_CALL_BEGIN):
            call_content, has_end, _ = part.partition(self.TOOL_CALL_END)
            if not has_end:
                continue

            type_part, has_sep, invoke_part = call_content.partition(self.TOOL_SEP)
            if not has_sep or type_part.strip() != "function":
                continue

            function_name, params_dict = self._parse_steptml_invoke(invoke_part)
            if not function_name or params_dict is None:
                continue

            final_args = self._cast_arguments(function_name, params_dict, request)
            tool_calls.append(
                ToolCall(
                    function=FunctionCall(
                        name=function_name,
                        arguments=json.dumps(final_args, ensure_ascii=False),
                    )
                )
            )

        if tool_calls:
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None,
            )
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )