step3_tool_parser.py 11.8 KB
Newer Older
Song's avatar
Song committed
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib
import json
from collections.abc import Sequence
7
from typing import Any
Song's avatar
Song committed
8
9
10

import regex as re

11
12
13
14
15
16
17
18
19
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
Song's avatar
Song committed
20
from vllm.logger import init_logger
21
from vllm.tokenizers import TokenizerLike
22
23
24
from vllm.tool_parsers.abstract_tool_parser import (
    ToolParser,
)
Song's avatar
Song committed
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from vllm.utils import random_uuid

logger = init_logger(__name__)


class Step3ToolParser(ToolParser):
    """
    Tool parser for the Step-3 tool-call format.

    Tool calls are delimited by marker tokens and each call body uses an
    XML-like "StepTML" syntax, roughly::

        <|tool_calls_begin|>
        <|tool_call_begin|>function<|tool_sep|><steptml:invoke name="NAME">
        <steptml:parameter name="ARG">VALUE</steptml:parameter>...
        <|tool_call_end|>
        ...more <|tool_call_begin|> ... <|tool_call_end|> pairs...
        <|tool_calls_end|>

    Text before ``<|tool_calls_begin|>`` and after ``<|tool_calls_end|>`` is
    ordinary content. The streaming path is a cursor-based state machine:
    ``self.position`` marks how much of ``current_text`` has been consumed.
    Each tool call is streamed as a header (id + name) as soon as the name is
    parsed, and its complete JSON arguments are sent in a single delta once
    ``<|tool_call_end|>`` arrives (together with the header if the call
    completed before the header was sent).
    """

    TOOL_CALLS_BEGIN = "<|tool_calls_begin|>"
    TOOL_CALLS_END = "<|tool_calls_end|>"
    TOOL_CALL_BEGIN = "<|tool_call_begin|>"
    TOOL_CALL_END = "<|tool_call_end|>"
    TOOL_SEP = "<|tool_sep|>"
    SPECIAL_TOKENS = [TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END]

    def __init__(self, tokenizer: TokenizerLike):
        super().__init__(tokenizer)
        # Index into the streamed ``current_text`` up to which it is consumed.
        self.position = 0
        # Explicit state flags for the streaming state machine.
        self.tool_block_started = False
        self.tool_block_finished = False

    def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Keep the tool-call markers in the decoded text when tools may be used.

        If the request carries tools and ``tool_choice`` is not ``"none"``,
        ``skip_special_tokens`` is disabled so the markers in
        ``SPECIAL_TOKENS`` reach this parser.
        """
        request = super().adjust_request(request)
        if request.tools and request.tool_choice != "none":
            request.skip_special_tokens = False
        return request

    @staticmethod
    def _parse_steptml_invoke(
        action_text: str,
    ) -> tuple[str | None, dict[str, str] | None]:
        """Extract the function name and raw parameters from a StepTML body.

        Returns ``(None, None)`` when no ``<steptml:invoke name="...">`` tag is
        present (yet). Otherwise returns the name and a dict of every
        parameter whose closing tag has been seen, with values stripped of
        surrounding whitespace. A repeated parameter name keeps its last value.
        """
        func_name_match = re.search(r'<steptml:invoke name="([^"]+)">', action_text)
        if not func_name_match:
            return None, None

        # Values may legitimately contain '<' (code, comparisons, markup), so
        # match lazily up to the closing tag rather than using [^<]*. The
        # tempered token refuses to run into another opening parameter tag, so
        # a parameter whose closing tag is missing is skipped instead of
        # swallowing the parameter that follows it.
        param_matches = re.findall(
            r'<steptml:parameter name="([^"]+)">'
            r"((?:(?!<steptml:parameter\b).)*?)"
            r"</steptml:parameter>",
            action_text,
            flags=re.DOTALL,
        )
        params: dict[str, str] = {
            name: value.strip() for name, value in param_matches
        }
        return func_name_match.group(1), params

    @staticmethod
    def _cast_value(value: str, typ: Any) -> Any:
        """Convert one raw string to the JSON-schema type ``typ``.

        Returns ``value`` unchanged when ``typ`` is unknown/absent or the
        string does not convert cleanly.
        """
        if typ == "string":
            return value.strip()
        elif typ == "integer":
            with contextlib.suppress(ValueError):
                return int(value)
        elif typ == "number":
            with contextlib.suppress(ValueError):
                return float(value)
        elif typ == "boolean":
            lower_val = value.lower()
            if lower_val in ("true", "false"):
                return lower_val == "true"
        elif typ == "null":
            if value.lower() == "null":
                return None
        elif typ in ("object", "array"):
            # Structured parameters arrive as JSON text; decode them so the
            # serialized arguments hold a real object/array rather than a
            # JSON-encoded string.
            with contextlib.suppress(json.JSONDecodeError):
                decoded = json.loads(value)
                if isinstance(decoded, dict if typ == "object" else list):
                    return decoded
        return value

    def _cast_arguments(
        self,
        func_name: str,
        params: dict[str, Any],
        request: ChatCompletionRequest,
    ) -> dict[str, Any]:
        """Coerce raw string parameters to the types declared in the tool schema.

        Looks up ``func_name`` in ``request.tools`` and converts each string
        value according to its property's ``type``. Non-string values, values
        without a declared type, and values that fail to convert are kept
        as-is. Returns a new dict; ``params`` itself is not modified.
        """
        casted = dict(params)
        for tool in request.tools or []:
            if tool.function.name != func_name:
                continue
            schema = tool.function.parameters or {}
            properties = schema.get("properties") or {}
            for key, value in params.items():
                if isinstance(value, str):
                    prop = properties.get(key) or {}
                    casted[key] = self._cast_value(value, prop.get("type"))
            break
        return casted

    def _start_tool_call(self) -> None:
        """Open a new tool-call slot after consuming ``<|tool_call_begin|>``."""
        # -1 (no call yet) + 1 == 0, so this covers the first call too.
        self.current_tool_id += 1
        self.current_tool_name_sent = False
        while len(self.prev_tool_call_arr) <= self.current_tool_id:
            self.prev_tool_call_arr.append({})
        while len(self.streamed_args_for_tool) <= self.current_tool_id:
            self.streamed_args_for_tool.append("")
        self.prev_tool_call_arr[self.current_tool_id]["finished"] = False

    def _discard_current_tool_call(self) -> None:
        """Drop the current slot so streamed tool-call indices stay contiguous.

        Used for a complete call whose body has no parseable function name;
        nothing was streamed for it.
        """
        del self.prev_tool_call_arr[self.current_tool_id :]
        del self.streamed_args_for_tool[self.current_tool_id :]
        self.current_tool_id -= 1
        self.current_tool_name_sent = False

    def _header_delta(
        self, function_name: str, arguments: str | None = None
    ) -> DeltaToolCall:
        """Build the first delta of the current call (id, type, name).

        ``arguments`` is attached only when given, i.e. when the call was
        already complete before its header could be sent.
        """
        self.current_tool_name_sent = True
        if arguments is None:
            function = DeltaFunctionCall(name=function_name)
        else:
            function = DeltaFunctionCall(name=function_name, arguments=arguments)
        return DeltaToolCall(
            index=self.current_tool_id,
            type="function",
            id=f"chatcmpl-tool-{random_uuid()}",
            function=function,
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Advance the streaming state machine over ``current_text``.

        Only ``current_text``, ``request`` and ``self.position`` drive the
        parse; the other arguments are part of the ``ToolParser`` interface
        and are not used here. Returns at most one delta per call (content, a
        tool-call header, or a tool call's arguments), or ``None`` when more
        text is needed.
        """
        while True:
            if self.position >= len(current_text):
                return None  # Everything received so far has been consumed.

            unprocessed_text = current_text[self.position :]

            # STATE: after the tool block, all remaining text is content.
            if self.tool_block_finished:
                self.position = len(current_text)
                return DeltaMessage(content=unprocessed_text)

            # STATE: before the tool block has started.
            if not self.tool_block_started:
                if unprocessed_text.startswith(self.TOOL_CALLS_BEGIN):
                    self.position += len(self.TOOL_CALLS_BEGIN)
                    self.tool_block_started = True
                    continue

                start_pos = unprocessed_text.find(self.TOOL_CALLS_BEGIN)
                if start_pos != -1:
                    # Flush the content preceding the block; the begin marker
                    # itself is consumed on the next call.
                    self.position += start_pos
                    return DeltaMessage(content=unprocessed_text[:start_pos])
                if self.TOOL_CALLS_BEGIN.startswith(unprocessed_text.strip()):
                    return None  # Possibly a partial begin marker; wait.
                self.position = len(current_text)
                return DeltaMessage(content=unprocessed_text)

            # STATE: inside the tool block; whitespace between markers is
            # skipped.
            stripped = unprocessed_text.lstrip()
            self.position += len(unprocessed_text) - len(stripped)
            unprocessed_text = stripped

            if unprocessed_text.startswith(self.TOOL_CALLS_END):
                self.position += len(self.TOOL_CALLS_END)
                self.tool_block_finished = True
                self.current_tool_id = -1
                continue

            # STATE: between tool calls (none started yet, or last finished).
            if self.current_tool_id == -1 or self.prev_tool_call_arr[
                self.current_tool_id
            ].get("finished"):
                if unprocessed_text.startswith(self.TOOL_CALL_BEGIN):
                    self.position += len(self.TOOL_CALL_BEGIN)
                    self._start_tool_call()
                    continue
                if self.TOOL_CALL_BEGIN.startswith(
                    unprocessed_text
                ) or self.TOOL_CALLS_END.startswith(unprocessed_text):
                    return None  # Possibly a partial marker; wait.
                # Stray text between calls: skip ahead to the next marker
                # rather than stalling forever. extract_tool_calls likewise
                # ignores text outside call delimiters in the tool block.
                next_marker_positions = [
                    pos
                    for pos in (
                        unprocessed_text.find(self.TOOL_CALL_BEGIN),
                        unprocessed_text.find(self.TOOL_CALLS_END),
                    )
                    if pos != -1
                ]
                if not next_marker_positions:
                    return None
                self.position += min(next_marker_positions)
                continue

            # STATE: parsing the active tool call.
            end_tool_pos = unprocessed_text.find(self.TOOL_CALL_END)
            call_complete = end_tool_pos != -1
            tool_body = (
                unprocessed_text[:end_tool_pos] if call_complete else unprocessed_text
            )

            function_name, arguments = self._parse_steptml_invoke(tool_body)
            if not function_name:
                if not call_complete:
                    return None  # Function name not fully received yet.
                # A finished call with no parseable name would otherwise block
                # the parser forever; drop it and keep going.
                logger.warning(
                    "Skipping tool call without a parsable function name: %r",
                    tool_body,
                )
                self.position += end_tool_pos + len(self.TOOL_CALL_END)
                self._discard_current_tool_call()
                continue

            tool_state = self.prev_tool_call_arr[self.current_tool_id]
            tool_state.update({"name": function_name, "parameters": arguments or {}})

            if not call_complete:
                if self.current_tool_name_sent:
                    # Arguments are only sent once the call is complete.
                    return None
                return DeltaMessage(tool_calls=[self._header_delta(function_name)])

            # The call is complete: consume it and send its final arguments.
            # Arguments are always sent, even "{}" for a parameterless call,
            # matching the non-streaming path.
            self.position += end_tool_pos + len(self.TOOL_CALL_END)
            final_args = self._cast_arguments(function_name, arguments or {}, request)
            final_args_json = json.dumps(final_args, ensure_ascii=False)
            # NOTE(review): "arguments" / streamed_args_for_tool mirror the
            # bookkeeping other ToolParser implementations keep for the serving
            # layer; confirm against the caller if that contract changes.
            tool_state.update(
                {"parameters": final_args, "arguments": final_args, "finished": True}
            )
            self.streamed_args_for_tool[self.current_tool_id] = final_args_json

            if self.current_tool_name_sent:
                tool_delta = DeltaToolCall(
                    index=self.current_tool_id,
                    function=DeltaFunctionCall(arguments=final_args_json),
                )
            else:
                # Call completed before its header went out: send both at once
                # so the arguments cannot be lost if no further delta arrives.
                tool_delta = self._header_delta(function_name, final_args_json)
            return DeltaMessage(tool_calls=[tool_delta])

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse every tool call from a complete (non-streamed) model output.

        Only ``function`` calls inside a complete
        ``<|tool_calls_begin|> ... <|tool_calls_end|>`` block are returned,
        with arguments cast to the declared schema types and JSON-encoded.
        Text outside the block (stripped) becomes the content. If the block is
        absent or incomplete, or no valid call is found, the whole output is
        returned as content with ``tools_called=False``.
        """
        if self.TOOL_CALLS_BEGIN not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        pre_text, rest = model_output.split(self.TOOL_CALLS_BEGIN, 1)
        if self.TOOL_CALLS_END not in rest:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        tool_block, post_text = rest.split(self.TOOL_CALLS_END, 1)
        content = (pre_text + post_text).strip()

        tool_calls: list[ToolCall] = []
        for part in tool_block.split(self.TOOL_CALL_BEGIN):
            if not part or self.TOOL_CALL_END not in part:
                continue

            call_content = part.split(self.TOOL_CALL_END, 1)[0]
            if self.TOOL_SEP not in call_content:
                continue

            type_part, invoke_part = call_content.split(self.TOOL_SEP, 1)
            if type_part.strip() != "function":
                continue

            function_name, params_dict = self._parse_steptml_invoke(invoke_part)
            if function_name and params_dict is not None:
                params_dict = self._cast_arguments(function_name, params_dict, request)
                params_str = json.dumps(params_dict, ensure_ascii=False)
                tool_calls.append(
                    ToolCall(
                        function=FunctionCall(name=function_name, arguments=params_str)
                    )
                )

        if tool_calls:
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None,
            )
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )