# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib
import json
from collections.abc import Sequence
from typing import Any

import regex as re

from vllm.entrypoints.openai.chat_completion.protocol import (
    ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
    Tool,
    ToolParser,
)
from vllm.utils import random_uuid

# Module-level logger obtained through vLLM's logging helper.
logger = init_logger(__name__)


class Step3ToolParser(ToolParser):
    """
    Tool parser for a model that uses a specific XML-like format for tool calls.

    The layout recognised by this parser is::

        <|tool_calls_begin|>
        <|tool_call_begin|>function<|tool_sep|><steptml:invoke name="NAME">
        <steptml:parameter name="KEY">VALUE</steptml:parameter>...
        <|tool_call_end|>
        ... (further <|tool_call_begin|> ... <|tool_call_end|> pairs)
        <|tool_calls_end|>

    Streaming uses a stateful, cursor-based parser (``self.position`` indexes
    into ``current_text``). Each call's name is streamed as soon as it is
    parsed, and its arguments are sent as one complete JSON string once the
    call has been fully received.
    """

    TOOL_CALLS_BEGIN = "<|tool_calls_begin|>"
    TOOL_CALLS_END = "<|tool_calls_end|>"
    TOOL_CALL_BEGIN = "<|tool_call_begin|>"
    TOOL_CALL_END = "<|tool_call_end|>"
    TOOL_SEP = "<|tool_sep|>"
    SPECIAL_TOKENS = [TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END]

    def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
        super().__init__(tokenizer, tools)
        # Index into ``current_text`` of the first character not yet consumed
        # by extract_tool_calls_streaming.
        self.position = 0
        # Explicit state flags for robust streaming
        self.tool_block_started = False
        self.tool_block_finished = False
        # NOTE(review): current_tool_id, prev_tool_call_arr and
        # current_tool_name_sent are used below but not set here; presumably
        # ToolParser.__init__ initialises them — confirm.

    def adjust_request(
        self, request: ChatCompletionRequest | ResponsesRequest
    ) -> ChatCompletionRequest | ResponsesRequest:
        """Disable ``skip_special_tokens`` when tools may be called.

        The tool-call markers (see ``SPECIAL_TOKENS``) have to stay in the
        decoded text for this parser to find them; presumably they are special
        tokens in the tokenizer. The flag is only changed when the request has
        tools and ``tool_choice`` is not ``"none"``.
        """
        request = super().adjust_request(request)
        if request.tools and request.tool_choice != "none":
            request.skip_special_tokens = False
        return request

    @staticmethod
    def _parse_steptml_invoke(
        action_text: str,
    ) -> tuple[str | None, dict[str, str] | None]:
        """Extract the function name and raw parameters from steptml text.

        Args:
            action_text: Text containing ``<steptml:invoke name="...">`` and
                zero or more complete
                ``<steptml:parameter name="...">VALUE</steptml:parameter>``
                elements.

        Returns:
            ``(name, params)``. ``params`` maps each parameter name to its
            whitespace-stripped string value; a later duplicate name
            overrides an earlier one. Returns ``(None, None)`` when no invoke
            tag is present. Parameters without a closing tag are ignored.
        """
        func_name_match = re.search(r'<steptml:invoke name="([^"]+)">', action_text)
        if not func_name_match:
            return None, None

        # Match non-greedily up to the closing tag so that values containing
        # "<" (code, comparisons, markup) are kept rather than silently dropped.
        param_matches = re.findall(
            r'<steptml:parameter name="([^"]+)">(.*?)</steptml:parameter>',
            action_text,
            flags=re.DOTALL,
        )
        params: dict[str, str] = {name: value.strip() for name, value in param_matches}
        return func_name_match.group(1), params

    @staticmethod
    def _convert_param_value(value: str, typ: Any) -> Any:
        """Convert one string value to the JSON-schema ``type`` *typ*.

        If *value* cannot be converted, or *typ* is not a recognised scalar,
        array or object type, the original string is returned unchanged.
        """
        if typ == "string":
            return value.strip()
        if typ == "integer":
            with contextlib.suppress(ValueError):
                return int(value)
            return value
        if typ == "number":
            with contextlib.suppress(ValueError):
                return float(value)
            return value
        if typ == "boolean":
            lower_val = value.lower()
            if lower_val in ("true", "false"):
                return lower_val == "true"
            return value
        if typ == "null":
            return None if value.lower() == "null" else value
        if typ in ("array", "object"):
            # Structured values arrive as JSON text. Decode them so the final
            # arguments are not double-encoded, but only accept a result of
            # the declared container type.
            expected_type = list if typ == "array" else dict
            with contextlib.suppress(ValueError):
                parsed = json.loads(value)
                if isinstance(parsed, expected_type):
                    return parsed
            return value
        return value

    def _cast_arguments(
        self,
        func_name: str,
        params: dict[str, Any],
        request: ChatCompletionRequest,
    ) -> dict[str, Any]:
        """Cast string parameter values to the types in the tool's schema.

        The first tool in ``request.tools`` named *func_name* supplies the
        JSON schema. Non-string values and parameters without a declared
        ``type`` are left as they are.

        Args:
            func_name: Name of the called function.
            params: Parameter values; converted **in place**.
            request: Request whose ``tools`` provide the schemas.

        Returns:
            The same *params* dict, after conversion.
        """
        for tool in request.tools or []:
            if tool.function.name != func_name:
                continue
            schema = tool.function.parameters or {}
            properties = schema.get("properties") or {}
            for key, value in params.items():
                if not isinstance(value, str):
                    continue
                prop = properties.get(key)
                # JSON schema allows non-dict (e.g. boolean) property schemas.
                typ = prop.get("type") if isinstance(prop, dict) else None
                params[key] = self._convert_param_value(value, typ)
            break
        return params

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Incrementally parse ``current_text`` starting at ``self.position``.

        Only ``current_text`` and ``request`` are used. The cursor in
        ``self.position`` records how much text has already been consumed. At
        most one delta is returned per call:

        * text before ``<|tool_calls_begin|>`` and after
          ``<|tool_calls_end|>`` is emitted as ``content``;
        * for each tool call, the function name is emitted with a fresh id as
          soon as the invoke tag has been parsed;
        * the type-cast arguments are emitted as one JSON string once
          ``<|tool_call_end|>`` arrives (``"{}"`` when the call has no
          parameters), matching ``extract_tool_calls``;
        * a complete tool call without a parseable invoke tag is skipped.

        Returns ``None`` when there is nothing to emit yet.
        """
        while True:
            if self.position >= len(current_text):
                return None  # Everything received so far has been consumed.

            unprocessed_text = current_text[self.position :]

            # STATE: after the tool block, all remaining text is content.
            if self.tool_block_finished:
                self.position = len(current_text)
                return DeltaMessage(content=unprocessed_text)

            # STATE: before the tool block has started.
            if not self.tool_block_started:
                if unprocessed_text.startswith(self.TOOL_CALLS_BEGIN):
                    self.position += len(self.TOOL_CALLS_BEGIN)
                    self.tool_block_started = True
                    continue  # Token consumed, re-loop.

                start_pos = unprocessed_text.find(self.TOOL_CALLS_BEGIN)
                if start_pos != -1:
                    # Emit the content before the marker; the marker itself is
                    # consumed on a later call.
                    self.position += start_pos
                    return DeltaMessage(content=unprocessed_text[:start_pos])
                if unprocessed_text and self.TOOL_CALLS_BEGIN.startswith(
                    unprocessed_text.strip()
                ):
                    return None  # Possibly a partial marker; wait.
                self.position = len(current_text)
                return DeltaMessage(content=unprocessed_text)

            # STATE: inside the tool block. Whitespace between markers is
            # skipped.
            stripped_text = unprocessed_text.lstrip()
            self.position += len(unprocessed_text) - len(stripped_text)
            unprocessed_text = stripped_text

            if unprocessed_text.startswith(self.TOOL_CALLS_END):
                self.position += len(self.TOOL_CALLS_END)
                self.tool_block_finished = True
                self.current_tool_id = -1
                continue

            in_active_call = self.current_tool_id != -1 and not (
                self.prev_tool_call_arr[self.current_tool_id].get("finished", False)
            )

            if not in_active_call:
                # STATE: between tool calls; the next marker must start a call.
                if not unprocessed_text.startswith(self.TOOL_CALL_BEGIN):
                    return None  # Partial marker or nothing parseable yet.
                self.position += len(self.TOOL_CALL_BEGIN)
                self.current_tool_id += 1  # -1 -> 0 for the first call.
                self.current_tool_name_sent = False
                while len(self.prev_tool_call_arr) <= self.current_tool_id:
                    self.prev_tool_call_arr.append({})
                self.prev_tool_call_arr[self.current_tool_id]["finished"] = False
                continue

            # STATE: parsing the active tool call.
            end_tool_pos = unprocessed_text.find(self.TOOL_CALL_END)
            call_complete = end_tool_pos != -1
            tool_body = (
                unprocessed_text[:end_tool_pos] if call_complete else unprocessed_text
            )

            if not call_complete and self.TOOL_CALL_END.startswith(tool_body):
                return None

            function_name, arguments = self._parse_steptml_invoke(tool_body)
            if not function_name:
                if not call_complete:
                    return None  # The invoke tag has not fully arrived yet.
                # The call is complete but has no parseable invoke tag. Skip it,
                # as extract_tool_calls does, instead of stalling on it forever.
                # Nothing has been emitted for this index (the name is only
                # sent after a successful parse), so release it for the next
                # call.
                self.position += end_tool_pos + len(self.TOOL_CALL_END)
                self.prev_tool_call_arr.pop(self.current_tool_id)
                self.current_tool_id -= 1
                continue

            tool_call_arr: dict[str, Any] = {
                "name": function_name,
                "parameters": arguments or {},
            }
            # Keep the per-call state in sync with the latest parse.
            self.prev_tool_call_arr[self.current_tool_id].update(tool_call_arr)

            # Send the function name as soon as it's parsed.
            if not self.current_tool_name_sent:
                self.current_tool_name_sent = True
                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_id,
                            type="function",
                            id=f"chatcmpl-tool-{random_uuid()}",
                            function=DeltaFunctionCall(name=function_name),
                        )
                    ]
                )

            # Arguments are only sent once the tool call is complete.
            if not call_complete:
                return None

            self.position += end_tool_pos + len(self.TOOL_CALL_END)
            self.prev_tool_call_arr[self.current_tool_id]["finished"] = True
            final_args = self._cast_arguments(
                function_name, tool_call_arr["parameters"], request
            )
            # Always emit the arguments, including "{}" for a call with no
            # parameters, so the streamed call matches extract_tool_calls.
            final_args_json = json.dumps(final_args, ensure_ascii=False)
            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_id,
                        function=DeltaFunctionCall(arguments=final_args_json),
                    )
                ]
            )

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse every tool call from a complete (non-streamed) model output.

        Only the first ``<|tool_calls_begin|>`` ... ``<|tool_calls_end|>``
        block is considered. Within it, each
        ``<|tool_call_begin|>function<|tool_sep|>...<|tool_call_end|>`` entry
        with a parseable invoke tag becomes a ``ToolCall``. Other entries are
        skipped.

        Returns:
            When at least one call is parsed: ``tools_called=True``, and the
            text around the tool block (stripped) as content, or ``None`` if
            that text is empty. Otherwise the whole *model_output* is returned
            as content.
        """
        if self.TOOL_CALLS_BEGIN not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        pre_text, rest = model_output.split(self.TOOL_CALLS_BEGIN, 1)
        if self.TOOL_CALLS_END not in rest:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        tool_block, post_text = rest.split(self.TOOL_CALLS_END, 1)
        content = (pre_text + post_text).strip()

        tool_calls: list[ToolCall] = []
        for part in tool_block.split(self.TOOL_CALL_BEGIN):
            # Skips text before the first call and unterminated calls.
            if self.TOOL_CALL_END not in part:
                continue

            call_content = part.split(self.TOOL_CALL_END, 1)[0]
            if self.TOOL_SEP not in call_content:
                continue

            type_part, invoke_part = call_content.split(self.TOOL_SEP, 1)
            if type_part.strip() != "function":
                continue

            function_name, params_dict = self._parse_steptml_invoke(invoke_part)
            if not function_name or params_dict is None:
                continue

            params_dict = self._cast_arguments(function_name, params_dict, request)
            params_str = json.dumps(params_dict, ensure_ascii=False)
            tool_calls.append(
                ToolCall(function=FunctionCall(name=function_name, arguments=params_str))
            )

        if tool_calls:
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content or None,
            )
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )