# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
from collections.abc import Sequence

import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase

import vllm.envs as envs
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import (
    ToolParser,
)
from vllm.tool_parsers.utils import (
    find_common_prefix,
    is_complete_json,
    partial_json_loads,
)

# Module-level logger used for this parser's debug, warning and exception messages.
logger = init_logger(__name__)


class Llama3JsonToolParser(ToolParser):
    """
    Tool call parser for Llama 3.x and 4 models intended for use with the
    examples/tool_chat_template_llama.jinja template.

    Used when --enable-auto-tool-choice --tool-call-parser llama3_json or
    llama4_json are set.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)

        # initialize properties used for state when parsing tool calls in
        # streaming mode

        # Tool-call dicts parsed on an earlier streaming step; compared with
        # the current parse to find the stable prefix of the argument JSON.
        self.prev_tool_call_arr: list[dict] = []
        # Index of the tool call currently being streamed (-1 = none yet).
        self.current_tool_id: int = -1
        # Whether the name of the current tool call has already been emitted.
        self.current_tool_name_sent: bool = False
        # Argument JSON text already streamed, one entry per tool call.
        self.streamed_args_for_tool: list[str] = []
        self.bot_token = "<|python_tag|>"
        # Id of the bot token (first id of its encoding, no special tokens).
        self.bot_token_id = tokenizer.encode(self.bot_token, add_special_tokens=False)[
            0
        ]

        # Simple regex to find opening braces - we'll use JSON decoder for parsing
        # This handles arbitrary nesting depth correctly
        self.tool_call_start_regex = re.compile(r"\{")
        self.json_decoder = json.JSONDecoder()

    def extract_tool_calls(
        self, model_output: str, request: ChatCompletionRequest
    ) -> ExtractedToolCallInformation:
        """
        Extract the tool calls from a complete model response.

        Only extracts JSON content and ignores any surrounding plain text.
        Supports both single JSON and multiple JSONs separated by semicolons
        (or written back to back, e.g. ``}{``).

        Every ``{`` that is not inside an already-parsed object is tried as
        the start of a JSON object. Braces that do not begin valid JSON are
        treated as plain text and skipped. A decoded object must have a
        ``name`` key and an ``arguments`` or ``parameters`` key; otherwise
        the whole output is returned as content with no tool calls.

        Args:
            model_output: The full generated text.
            request: The originating chat request (not used here).

        Returns:
            ``tools_called=True`` with the parsed calls and ``content=None``
            when at least one tool call was found; otherwise
            ``tools_called=False`` with ``model_output`` as content.
        """
        # Quick check before running regex
        if not (self.bot_token in model_output or "{" in model_output):
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        # Exclusive end index of the last parsed JSON object, so braces
        # nested inside it are not parsed again as separate objects.
        end_index = 0
        tool_calls: list[ToolCall] = []

        try:
            for match in self.tool_call_start_regex.finditer(
                model_output, timeout=envs.VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS
            ):
                start_index = match.start()
                # Skip if this brace is inside a previously parsed JSON object.
                # raw_decode reports an *exclusive* end, so a brace located
                # exactly at end_index begins a new object and must be kept.
                if start_index < end_index:
                    continue

                try:
                    obj, json_end_index = self.json_decoder.raw_decode(
                        model_output[start_index:]
                    )
                    end_index = start_index + json_end_index

                    # raise KeyError if missing
                    name = obj["name"]
                    arguments_or_params = (
                        obj["arguments"] if "arguments" in obj else obj["parameters"]
                    )

                    tool_calls.append(
                        ToolCall(
                            type="function",
                            function=FunctionCall(
                                name=name,
                                # function call args are JSON but as a string
                                arguments=json.dumps(
                                    arguments_or_params, ensure_ascii=False
                                ),
                            ),
                        )
                    )
                except json.JSONDecodeError:
                    # This brace is plain text rather than the start of a
                    # JSON object; keep scanning for real tool calls.
                    logger.debug(
                        "Skipping non-JSON '{' at index %d in model output.",
                        start_index,
                    )
                    continue
                except KeyError as e:
                    # Missing required key
                    missing_key = str(e).strip("'\"")
                    logger.exception(
                        "Couldn't extract tool call from JSON response. "
                        "Required key '%s' not present. "
                        "Returning output in content with empty tool calls.",
                        missing_key,
                    )
                    return ExtractedToolCallInformation(
                        tools_called=False, tool_calls=[], content=model_output
                    )
                except Exception:
                    # Any other error during parsing
                    logger.exception(
                        "Error in extracting tool call from response. "
                        "Returning output in content with empty tool calls"
                    )
                    return ExtractedToolCallInformation(
                        tools_called=False, tool_calls=[], content=model_output
                    )
        except TimeoutError:
            logger.warning("Regex timeout occurred when matching tool call pattern.")
            logger.debug(
                "Regex timeout occurred when matching user input: %s", model_output
            )
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        # If we have valid tool calls, return them normally
        if tool_calls:
            return ExtractedToolCallInformation(
                tools_called=True, tool_calls=tool_calls, content=None
            )

        # No valid tool calls found
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """
        Extract tool-call deltas incrementally from streamed model output.

        ``current_text`` (everything generated so far) is re-parsed on every
        call as a sequence of possibly incomplete JSON objects. Consecutive
        objects are assumed to be separated by the two-character "; ". The
        instance state (``current_tool_id``, ``current_tool_name_sent``,
        ``streamed_args_for_tool``, ``prev_tool_call_arr``) records what has
        already been sent, so each call emits only the new part.

        Returns:
            - ``DeltaMessage(content=delta_text)`` when the output does not
              start with the bot token or ``{``.
            - A ``DeltaMessage`` carrying a tool-call delta (name or argument
              fragment) when there is something new to stream.
            - ``None`` when there is nothing to send yet, or when an error
              occurred (the error is logged).
        """
        if not (
            current_text.startswith(self.bot_token) or current_text.startswith("{")
        ):
            return DeltaMessage(content=delta_text)

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
        try:
            tool_call_arr = []
            is_complete = []
            try:
                # depending on the prompt format the Llama model may or may not
                # prefix the output with the <|python_tag|> token
                start_idx = (
                    len(self.bot_token)
                    if current_text.startswith(self.bot_token)
                    else 0
                )
                while start_idx < len(current_text):
                    (obj, end_idx) = partial_json_loads(current_text[start_idx:], flags)
                    is_complete.append(
                        is_complete_json(current_text[start_idx : start_idx + end_idx])
                    )
                    # Skip past this object plus the assumed "; " separator.
                    start_idx += end_idx + len("; ")
                    # depending on the prompt Llama can use
                    # either arguments or parameters
                    if "parameters" in obj:
                        # Explicit check rather than `assert`, which is
                        # stripped under `python -O`; the outer handler
                        # logs it and skips the chunk.
                        if "arguments" in obj:
                            raise ValueError(
                                "model generated both parameters and arguments"
                            )
                        obj["arguments"] = obj["parameters"]
                    tool_call_arr.append(obj)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug("not enough tokens to parse into JSON yet")
                return None

            # select as the current tool call the one we're on the state at
            # (while current_tool_id is -1 this is the last element, but it is
            # only read below once current_tool_id >= 0)
            current_tool_call: dict = (
                tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
            )

            # case -- if no tokens have been streamed for the tool, e.g.
            #   only the array brackets, stream nothing
            if len(tool_call_arr) == 0:
                return None

            # case: we are starting a new tool in the array
            #   -> array has > 0 length AND length has moved past cursor
            elif (
                len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
            ):
                # if we're moving on to a new call, first make sure we
                # haven't missed anything in the previous one that was
                # auto-generated due to JSON completions, but wasn't
                # streamed to the client yet.
                if self.current_tool_id >= 0:
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
                        sent = len(self.streamed_args_for_tool[self.current_tool_id])
                        argument_diff = cur_args_json[sent:]

                        logger.debug("got arguments diff: %s", argument_diff)
                        delta = DeltaMessage(
                            tool_calls=[
                                DeltaToolCall(
                                    index=self.current_tool_id,
                                    function=DeltaFunctionCall(
                                        arguments=argument_diff
                                    ).model_dump(exclude_none=True),
                                )
                            ]
                        )
                        self.streamed_args_for_tool[self.current_tool_id] += (
                            argument_diff
                        )
                    else:
                        delta = None
                else:
                    delta = None
                # re-set stuff pertaining to progress in the current tool
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("starting on new tool %d", self.current_tool_id)
                return delta

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            elif not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name:
                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                type="function",
                                id=make_tool_call_id(),
                                function=DeltaFunctionCall(
                                    name=function_name
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.current_tool_name_sent = True
                else:
                    delta = None

            # now we know we're on the same tool call and we're streaming
            # arguments
            else:
                cur_arguments = current_tool_call.get("arguments")
                delta = None

                if cur_arguments:
                    sent = len(self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
                    prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
                        "arguments"
                    )

                    argument_diff = None
                    if is_complete[self.current_tool_id]:
                        # Object is complete: everything not yet sent is final.
                        argument_diff = cur_args_json[sent:]
                    elif prev_arguments:
                        # Object still partial: only stream the part that is
                        # stable between the previous and the current parse.
                        prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)
                        if cur_args_json != prev_args_json:
                            prefix = find_common_prefix(prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        delta = DeltaMessage(
                            tool_calls=[
                                DeltaToolCall(
                                    index=self.current_tool_id,
                                    function=DeltaFunctionCall(
                                        arguments=argument_diff
                                    ).model_dump(exclude_none=True),
                                )
                            ]
                        )
                        self.streamed_args_for_tool[self.current_tool_id] += (
                            argument_diff
                        )

            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction error"
            )
            return None