# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib
import json
from collections.abc import Sequence
from typing import Any

import regex as re

from vllm.entrypoints.openai.chat_completion.protocol import (
    ChatCompletionRequest,
)
from vllm.entrypoints.openai.engine.protocol import (
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
    Tool,
    ToolParser,
)
from vllm.utils import random_uuid

# Module-level logger, named after this module via __name__.
logger = init_logger(__name__)


class Step3ToolParser(ToolParser):
    """
    Tool parser for Step3 models, whose tool calls are delimited by special
    tokens and whose call bodies use a ``steptml`` XML-like syntax::

        <|tool_calls_begin|>
        <|tool_call_begin|>function<|tool_sep|><steptml:invoke name="NAME">
        <steptml:parameter name="ARG">VALUE</steptml:parameter> ...
        <|tool_call_end|>
        ... further <|tool_call_begin|> ... <|tool_call_end|> calls ...
        <|tool_calls_end|>

    Parameter values arrive as text and are cast to the JSON types declared in
    the matching tool's schema (see ``_cast_arguments``).

    The streaming parser is a cursor-based state machine over ``current_text``.
    It streams the function name as soon as the invoke tag is complete and
    sends the arguments as a single JSON string once the call's end marker
    arrives. Every event found in one call is merged into one ``DeltaMessage``,
    so nothing is left pending when the stream ends.
    """

    TOOL_CALLS_BEGIN = "<|tool_calls_begin|>"
    TOOL_CALLS_END = "<|tool_calls_end|>"
    TOOL_CALL_BEGIN = "<|tool_call_begin|>"
    TOOL_CALL_END = "<|tool_call_end|>"
    TOOL_SEP = "<|tool_sep|>"
    SPECIAL_TOKENS = [TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END]

    # Compiled once here because they run on every streamed re-parse.
    _INVOKE_PATTERN = re.compile(r'<steptml:invoke name="([^"]+)">')
    # Non-greedy with DOTALL: a value ends at its own closing tag. It may
    # contain "<" (code, comparisons, markup) and newlines.
    _PARAMETER_PATTERN = re.compile(
        r'<steptml:parameter name="([^"]+)">(.*?)</steptml:parameter>',
        re.DOTALL,
    )

    def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
        super().__init__(tokenizer, tools)
        # Offset into current_text up to which streamed output is consumed.
        self.position = 0
        # Explicit state flags for robust streaming.
        self.tool_block_started = False
        self.tool_block_finished = False

    def adjust_request(
        self, request: ChatCompletionRequest | ResponsesRequest
    ) -> ChatCompletionRequest | ResponsesRequest:
        """Keep special tokens in the output when tools may be called.

        The tool-call markers (see SPECIAL_TOKENS) must reach the parser, so
        ``skip_special_tokens`` is disabled unless tool_choice is "none".
        """
        request = super().adjust_request(request)
        if request.tools and request.tool_choice != "none":
            request.skip_special_tokens = False
        return request

    @classmethod
    def _parse_steptml_invoke(
        cls,
        action_text: str,
    ) -> tuple[str | None, dict[str, str] | None]:
        """Extract the function name and raw parameters from a call body.

        Args:
            action_text: Text containing a ``<steptml:invoke name="...">`` tag
                followed by zero or more ``<steptml:parameter>`` elements.

        Returns:
            ``(name, params)``, where params maps parameter names to stripped
            string values; a repeated name keeps its last value.
            ``(None, None)`` when no complete invoke tag is present.
        """
        func_name_match = cls._INVOKE_PATTERN.search(action_text)
        if not func_name_match:
            return None, None
        params: dict[str, str] = {
            name: value.strip()
            for name, value in cls._PARAMETER_PATTERN.findall(action_text)
        }
        return func_name_match.group(1), params

    def _cast_arguments(
        self,
        func_name: str,
        params: dict[str, Any],
    ) -> dict[str, Any]:
        """Cast string values to the JSON types declared by the tool schema.

        Looks up ``func_name`` in ``self.tools`` and, for each string value,
        applies the ``type`` declared under the schema's ``properties``.
        Values that do not convert cleanly are left as strings. Parameters
        without a declared type, and tools that are not found, are left
        untouched.

        Args:
            func_name: Name of the called function.
            params: Parameter values; modified in place.

        Returns:
            The same ``params`` dict, after casting.
        """
        for tool in self.tools or []:
            if tool.function.name != func_name:
                continue
            schema = tool.function.parameters or {}
            properties = schema.get("properties", {})
            for key, value in params.items():
                if not isinstance(value, str):
                    continue
                typ = properties.get(key, {}).get("type")
                if typ == "string":
                    params[key] = value.strip()
                elif typ == "integer":
                    with contextlib.suppress(ValueError):
                        params[key] = int(value)
                elif typ == "number":
                    with contextlib.suppress(ValueError):
                        params[key] = float(value)
                elif typ == "boolean":
                    lower_val = value.lower()
                    params[key] = (
                        lower_val == "true"
                        if lower_val in ("true", "false")
                        else value
                    )
                elif typ == "null":
                    params[key] = None if value.lower() == "null" else value
                elif typ in ("array", "object"):
                    # Structured values arrive as JSON text. Decode them only
                    # when the result has the declared container type.
                    expected = list if typ == "array" else dict
                    with contextlib.suppress(ValueError):
                        decoded = json.loads(value)
                        if isinstance(decoded, expected):
                            params[key] = decoded
            break
        return params

    # ------------------------------------------------------------------
    # Streaming helpers
    # ------------------------------------------------------------------

    def _tool_call_is_open(self) -> bool:
        """Return True while a tool call has been opened but not yet closed."""
        return self.current_tool_id != -1 and not self.prev_tool_call_arr[
            self.current_tool_id
        ].get("finished", False)

    def _open_tool_call(self) -> None:
        """Consume TOOL_CALL_BEGIN and allocate state for the next tool call."""
        self.position += len(self.TOOL_CALL_BEGIN)
        # current_tool_id is -1 before the first call, so this yields 0 first.
        self.current_tool_id += 1
        self.current_tool_name_sent = False
        while len(self.prev_tool_call_arr) <= self.current_tool_id:
            self.prev_tool_call_arr.append({})
        self.prev_tool_call_arr[self.current_tool_id]["finished"] = False

    def _discard_open_tool_call(self) -> None:
        """Roll back the open tool call when its body turned out malformed.

        Nothing was streamed for the call (its name never parsed), so its
        index is released for the next call and emitted indices stay
        contiguous.
        """
        self.prev_tool_call_arr.pop(self.current_tool_id)
        self.current_tool_id -= 1

    def _find_next_marker(self, text: str) -> int:
        """Return the offset of the first call/block marker in text, or -1."""
        offsets = [
            pos
            for pos in (
                text.find(self.TOOL_CALL_BEGIN),
                text.find(self.TOOL_CALLS_END),
            )
            if pos != -1
        ]
        return min(offsets, default=-1)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Advance the streaming state machine over ``current_text``.

        Only ``current_text`` is used: text from ``self.position`` onward is
        processed until it is exhausted or more tokens are needed. Everything
        recognized along the way (plain content, tool names, tool arguments)
        is merged into a single ``DeltaMessage``.

        Returns:
            A ``DeltaMessage`` with only the fields that have something to
            report, or ``None`` when nothing new can be emitted yet.
        """
        content_parts: list[str] = []
        tool_deltas: list[DeltaToolCall] = []

        while self.position < len(current_text):
            unprocessed_text = current_text[self.position :]

            # STATE: after all tool calls, all subsequent text is content.
            if self.tool_block_finished:
                content_parts.append(unprocessed_text)
                self.position = len(current_text)
                break

            # STATE: before the tool block has started.
            if not self.tool_block_started:
                start_pos = unprocessed_text.find(self.TOOL_CALLS_BEGIN)
                if start_pos == 0:
                    self.position += len(self.TOOL_CALLS_BEGIN)
                    self.tool_block_started = True
                    continue
                if start_pos > 0:
                    content_parts.append(unprocessed_text[:start_pos])
                    self.position += start_pos
                    continue
                if self.TOOL_CALLS_BEGIN.startswith(unprocessed_text.strip()):
                    # May still turn into the begin marker; wait for more.
                    break
                content_parts.append(unprocessed_text)
                self.position = len(current_text)
                break

            # STATE: inside the tool block; whitespace between markers is
            # insignificant.
            stripped = unprocessed_text.lstrip()
            self.position += len(unprocessed_text) - len(stripped)
            unprocessed_text = stripped

            if unprocessed_text.startswith(self.TOOL_CALLS_END):
                self.position += len(self.TOOL_CALLS_END)
                self.tool_block_finished = True
                self.current_tool_id = -1
                continue

            # STATE: between tool calls.
            if not self._tool_call_is_open():
                if unprocessed_text.startswith(self.TOOL_CALL_BEGIN):
                    self._open_tool_call()
                    continue
                marker_pos = self._find_next_marker(unprocessed_text)
                if marker_pos > 0:
                    # Stray text between calls is dropped (as in the
                    # non-streaming path) instead of stalling the parser.
                    logger.debug(
                        "Skipping unexpected text between tool calls: %r",
                        unprocessed_text[:marker_pos],
                    )
                    self.position += marker_pos
                    continue
                break  # Wait for the next marker.

            # STATE: parsing the open tool call.
            end_tool_pos = unprocessed_text.find(self.TOOL_CALL_END)
            call_complete = end_tool_pos != -1
            tool_body = (
                unprocessed_text[:end_tool_pos] if call_complete else unprocessed_text
            )
            if not call_complete and self.TOOL_CALL_END.startswith(tool_body):
                break

            function_name, arguments = self._parse_steptml_invoke(tool_body)
            if not function_name:
                if not call_complete:
                    break  # The invoke tag has not fully arrived yet.
                # A finished call without a parseable invoke tag can never
                # become valid: consume it so later output is not swallowed.
                logger.warning("Dropping malformed Step3 tool call: %r", tool_body)
                self.position += end_tool_pos + len(self.TOOL_CALL_END)
                self._discard_open_tool_call()
                continue

            tool_state = self.prev_tool_call_arr[self.current_tool_id]
            tool_state.update({"name": function_name, "parameters": arguments or {}})
            send_name = not self.current_tool_name_sent
            self.current_tool_name_sent = True

            if not call_complete:
                # Stream the name now; arguments follow when the call closes.
                if send_name:
                    tool_deltas.append(
                        DeltaToolCall(
                            index=self.current_tool_id,
                            type="function",
                            id=f"chatcmpl-tool-{random_uuid()}",
                            function=DeltaFunctionCall(name=function_name),
                        )
                    )
                break

            self.position += end_tool_pos + len(self.TOOL_CALL_END)
            tool_state["finished"] = True
            final_args = self._cast_arguments(function_name, tool_state["parameters"])
            # Always send the arguments, so zero-argument tools get "{}"
            # exactly like the non-streaming path.
            final_args_json = json.dumps(final_args, ensure_ascii=False)
            if send_name:
                # Name and arguments arrived in the same call: send both at once.
                tool_deltas.append(
                    DeltaToolCall(
                        index=self.current_tool_id,
                        type="function",
                        id=f"chatcmpl-tool-{random_uuid()}",
                        function=DeltaFunctionCall(
                            name=function_name, arguments=final_args_json
                        ),
                    )
                )
            else:
                tool_deltas.append(
                    DeltaToolCall(
                        index=self.current_tool_id,
                        function=DeltaFunctionCall(arguments=final_args_json),
                    )
                )

        # Only set fields that carry data, so serialized chunks keep the same
        # shape as single-event messages.
        message_fields: dict[str, Any] = {}
        if content_parts:
            message_fields["content"] = "".join(content_parts)
        if tool_deltas:
            message_fields["tool_calls"] = tool_deltas
        return DeltaMessage(**message_fields) if message_fields else None

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse every tool call from a complete model output.

        Only the first ``TOOL_CALLS_BEGIN`` ... ``TOOL_CALLS_END`` block is
        parsed. Text before and after it becomes the content (stripped, or
        ``None`` when empty). Calls whose type is not ``function``, or that
        lack a separator, end marker or invoke tag, are skipped.

        Returns:
            Extracted tool calls. When the block is missing or unterminated,
            or yields no valid call, ``tools_called`` is False and the content
            is the untouched ``model_output``.
        """
        if self.TOOL_CALLS_BEGIN not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        pre_text, rest = model_output.split(self.TOOL_CALLS_BEGIN, 1)
        if self.TOOL_CALLS_END not in rest:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        tool_block, post_text = rest.split(self.TOOL_CALLS_END, 1)
        content = (pre_text + post_text).strip()

        tool_calls: list[ToolCall] = []
        for part in tool_block.split(self.TOOL_CALL_BEGIN):
            if not part or self.TOOL_CALL_END not in part:
                continue

            call_content = part.split(self.TOOL_CALL_END, 1)[0]
            if self.TOOL_SEP not in call_content:
                continue

            type_part, invoke_part = call_content.split(self.TOOL_SEP, 1)
            if type_part.strip() != "function":
                continue

            function_name, params_dict = self._parse_steptml_invoke(invoke_part)
            if not function_name or params_dict is None:
                continue

            params_dict = self._cast_arguments(function_name, params_dict)
            params_str = json.dumps(params_dict, ensure_ascii=False)
            tool_calls.append(
                ToolCall(function=FunctionCall(name=function_name, arguments=params_str))
            )

        if tool_calls:
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None,
            )
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )