llama_tool_parser.py 13.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import json
5
from collections.abc import Sequence
6
7

import partial_json_parser
8
import regex as re
9
10
11
from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase

12
import vllm.envs as envs
13
from vllm.entrypoints.chat_utils import make_tool_call_id
14
from vllm.entrypoints.openai.chat_completion.protocol import (
15
    ChatCompletionRequest,
16
17
)
from vllm.entrypoints.openai.engine.protocol import (
18
19
20
21
22
23
24
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
25
26
from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import (
27
28
    ToolParser,
)
29
from vllm.tool_parsers.utils import (
30
31
32
33
    find_common_prefix,
    is_complete_json,
    partial_json_loads,
)
34
35
36
37
38
39

# Module-level logger for this parser, named after the module (``__name__``).
logger = init_logger(__name__)


class Llama3JsonToolParser(ToolParser):
    """
    Tool call parser for Llama 3.x and 4 models intended for use with the
    examples/tool_chat_template_llama.jinja template.

    Used when --enable-auto-tool-choice --tool-call-parser llama3_json or
    llama4_json are set.

    Tool calls are expected as one or more JSON objects of the form
    ``{"name": ..., "arguments": {...}}`` (``"parameters"`` is accepted as an
    alias for ``"arguments"``), optionally prefixed with the
    ``<|python_tag|>`` token and separated by semicolons.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)

        # State used when parsing tool calls in streaming mode.
        # Tool-call dicts parsed at the previous streaming step.
        self.prev_tool_call_arr: list[dict] = []
        # Index of the tool call currently being streamed (-1 = none yet).
        self.current_tool_id: int = -1
        # Whether the name of the current tool call has already been sent.
        self.current_tool_name_sent: bool = False
        # Argument JSON text streamed so far, one entry per tool call.
        self.streamed_args_for_tool: list[str] = []

        self.bot_token = "<|python_tag|>"
        self.bot_token_id = tokenizer.encode(self.bot_token, add_special_tokens=False)[
            0
        ]

        # Simple regex to find opening braces - we'll use JSON decoder for parsing
        # This handles arbitrary nesting depth correctly
        self.tool_call_start_regex = re.compile(r"\{")
        self.json_decoder = json.JSONDecoder()

    def extract_tool_calls(
        self, model_output: str, request: ChatCompletionRequest
    ) -> ExtractedToolCallInformation:
        """
        Extract the tool calls from a complete model response.

        Scans the output for ``{`` characters and decodes a JSON object at
        each one that is not nested inside an already-decoded object, so
        plain text around and between the objects is not parsed.
        Supports both single JSON and multiple JSONs separated by semicolons.

        Args:
            model_output: The full text generated by the model.
            request: The originating request (not used by this parser).

        Returns:
            ``tools_called=True`` with the parsed calls and ``content=None``
            when at least one tool call was decoded. Otherwise ``tools_called``
            is False and the original text is returned as content. That
            covers: no brace or bot token; a brace that does not start valid
            JSON; a decoded object missing ``"name"`` or both
            ``"arguments"``/``"parameters"``; and a regex timeout.
        """
        # Quick check before running regex
        if not (self.bot_token in model_output or "{" in model_output):
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        # Exclusive end index of the last decoded JSON object. Braces before
        # it belong to that object and must not be decoded again.
        end_index = 0
        tool_calls: list[ToolCall] = []

        try:
            for match in self.tool_call_start_regex.finditer(
                model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS
            ):
                start_index = match.start()
                # Skip braces nested inside a previously decoded object. The
                # comparison must be strict: a brace located exactly at
                # ``end_index`` directly follows the previous object (e.g.
                # ``{...}{...}``) and starts a new one.
                if start_index < end_index:
                    continue

                try:
                    # Decode in place instead of slicing, which would copy the
                    # remainder of the string for every object. The returned
                    # end index is absolute within ``model_output``.
                    obj, end_index = self.json_decoder.raw_decode(
                        model_output, start_index
                    )

                    # raise KeyError if missing
                    name = obj["name"]
                    arguments_or_params = (
                        obj["arguments"] if "arguments" in obj else obj["parameters"]
                    )

                    tool_calls.append(
                        ToolCall(
                            type="function",
                            function=FunctionCall(
                                name=name,
                                # function call args are JSON but as a string
                                arguments=json.dumps(
                                    arguments_or_params, ensure_ascii=False
                                ),
                            ),
                        )
                    )
                except KeyError as e:
                    # Missing required key
                    missing_key = str(e).strip("'\"")
                    logger.exception(
                        "Couldn't extract tool call from JSON response. "
                        "Required key '%s' not present. "
                        "Returning output in content with empty tool calls.",
                        missing_key,
                    )
                    return ExtractedToolCallInformation(
                        tools_called=False, tool_calls=[], content=model_output
                    )
                except Exception:
                    # Any other error during parsing (e.g. invalid JSON)
                    logger.exception(
                        "Error in extracting tool call from response. "
                        "Returning output in content with empty tool calls"
                    )
                    return ExtractedToolCallInformation(
                        tools_called=False, tool_calls=[], content=model_output
                    )
        except TimeoutError:
            logger.warning("Regex timeout occurred when matching tool call pattern.")
            logger.debug(
                "Regex timeout occurred when matching user input: %s", model_output
            )
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        # If we have valid tool calls, return them normally
        if tool_calls:
            return ExtractedToolCallInformation(
                tools_called=True, tool_calls=tool_calls, content=None
            )

        # No valid tool calls found
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """
        Incrementally extract tool calls while the response is streamed.

        ``current_text`` is re-parsed on every call with a partial JSON
        parser. The result is diffed against what has already been sent,
        tracked in ``current_tool_id``, ``current_tool_name_sent``,
        ``streamed_args_for_tool`` and ``prev_tool_call_arr``.

        Returns:
            One of three results:

            * a ``DeltaMessage`` with ``delta_text`` as plain content if the
              output does not start with ``<|python_tag|>`` or ``{``;
            * a ``DeltaMessage`` carrying a tool-call name or an argument
              fragment;
            * ``None`` when there is nothing to send yet or parsing failed.
        """
        if not current_text.startswith((self.bot_token, "{")):
            return DeltaMessage(content=delta_text)

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR

        try:
            tool_call_arr: list[dict] = []
            is_complete: list[bool] = []
            try:
                # depending on the prompt format the Llama model may or may not
                # prefix the output with the <|python_tag|> token
                start_idx = (
                    len(self.bot_token)
                    if current_text.startswith(self.bot_token)
                    else 0
                )
                while start_idx < len(current_text):
                    (obj, end_idx) = partial_json_loads(current_text[start_idx:], flags)
                    is_complete.append(
                        is_complete_json(current_text[start_idx : start_idx + end_idx])
                    )
                    # Resume at the next "{" after this object rather than
                    # assuming the separator is exactly "; ": ";" without a
                    # space, or extra whitespace/newlines, must also work.
                    # max(..., 1) guarantees forward progress.
                    next_start = current_text.find("{", start_idx + max(end_idx, 1))
                    start_idx = next_start if next_start != -1 else len(current_text)

                    # depending on the prompt Llama can use
                    # either arguments or parameters
                    if "parameters" in obj:
                        # Explicit check instead of assert so it also holds
                        # under ``python -O``; caught by the outer handler.
                        if "arguments" in obj:
                            raise ValueError(
                                "model generated both parameters and arguments"
                            )
                        obj["arguments"] = obj["parameters"]
                    tool_call_arr.append(obj)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug("not enough tokens to parse into JSON yet")
                return None

            # case -- if no tokens have been streamed for the tool, e.g.
            #   only the array brackets, stream nothing
            if not tool_call_arr:
                return None

            # select as the current tool call the one we're on the state at
            # (while current_tool_id is still -1 this is the last entry)
            current_tool_call: dict = tool_call_arr[self.current_tool_id]

            # case: we are starting a new tool in the array
            #   -> length has moved past cursor
            if len(tool_call_arr) > self.current_tool_id + 1:
                # if we're moving on to a new call, first make sure we
                # haven't missed anything in the previous one that was
                # auto-generated due to JSON completions, but wasn't
                # streamed to the client yet.
                delta = None
                if self.current_tool_id >= 0:
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
                        sent = len(self.streamed_args_for_tool[self.current_tool_id])
                        argument_diff = cur_args_json[sent:]

                        logger.debug("got arguments diff: %s", argument_diff)
                        delta = DeltaMessage(
                            tool_calls=[
                                DeltaToolCall(
                                    index=self.current_tool_id,
                                    function=DeltaFunctionCall(
                                        arguments=argument_diff
                                    ).model_dump(exclude_none=True),
                                )
                            ]
                        )
                        self.streamed_args_for_tool[self.current_tool_id] += (
                            argument_diff
                        )
                # re-set stuff pertaining to progress in the current tool
                # NOTE(review): if more than one new object appears in a
                # single step, the intermediate ones are skipped and
                # streamed_args_for_tool gets only one new entry, so it falls
                # out of step with current_tool_id.
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("starting on new tool %d", self.current_tool_id)
                return delta

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            if not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name:
                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                type="function",
                                id=make_tool_call_id(),
                                function=DeltaFunctionCall(
                                    name=function_name
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.current_tool_name_sent = True
                else:
                    delta = None

            # now we know we're on the same tool call and we're streaming
            # arguments
            else:
                cur_arguments = current_tool_call.get("arguments")
                delta = None

                if cur_arguments:
                    sent = len(self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
                    prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
                        "arguments"
                    )

                    argument_diff = None
                    if is_complete[self.current_tool_id]:
                        # Object is complete: flush everything not yet sent.
                        argument_diff = cur_args_json[sent:]
                    elif prev_arguments:
                        # Only send the part that is stable across two parses,
                        # since the partial parser may auto-complete the tail.
                        prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)
                        if cur_args_json != prev_args_json:
                            prefix = find_common_prefix(prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        delta = DeltaMessage(
                            tool_calls=[
                                DeltaToolCall(
                                    index=self.current_tool_id,
                                    function=DeltaFunctionCall(
                                        arguments=argument_diff
                                    ).model_dump(exclude_none=True),
                                )
                            ]
                        )
                        self.streamed_args_for_tool[self.current_tool_id] += (
                            argument_diff
                        )

            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction error"
            )
            return None