internlm2_tool_parser.py 9.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import json
5
from collections.abc import Sequence
6
7
8
9

import partial_json_parser
from partial_json_parser.core.options import Allow

10
from vllm.entrypoints.chat_utils import make_tool_call_id
11
from vllm.entrypoints.openai.chat_completion.protocol import (
12
    ChatCompletionRequest,
13
14
)
from vllm.entrypoints.openai.engine.protocol import (
15
16
17
18
19
20
21
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
22
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
23
from vllm.logger import init_logger
24
from vllm.tokenizers import TokenizerLike
25
from vllm.tool_parsers.abstract_tool_parser import (
26
    Tool,
27
28
29
    ToolParser,
)
from vllm.tool_parsers.utils import extract_intermediate_diff
30
31
32
33
34

logger = init_logger(__name__)


class Internlm2ToolParser(ToolParser):
    """Tool-call parser for InternLM2-style model output.

    A tool call is expected as a single JSON object that follows the
    ``<|action_start|><|plugin|>`` marker and is terminated by
    ``<|action_end|>``. The object carries the function name under ``"name"``
    and its arguments under either ``"parameters"`` or ``"arguments"``.
    Only one tool call per response is handled; parallel tool calls are not
    supported.
    """

    def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
        super().__init__(tokenizer, tools)
        # Number of characters of the generated text that have already been
        # forwarded to the client as plain content while streaming.
        self.position = 0

    def adjust_request(
        self, request: ChatCompletionRequest | ResponsesRequest
    ) -> ChatCompletionRequest | ResponsesRequest:
        """Keep special tokens in the output when tool calling is enabled.

        Args:
            request: The incoming request.

        Returns:
            The request as returned by the parent ``adjust_request``, with
            ``skip_special_tokens`` set to ``False`` when tools are supplied
            and ``tool_choice`` is not ``"none"``.
        """
        request = super().adjust_request(request)
        if request.tools and request.tool_choice != "none":
            # do not skip special tokens because internlm uses the special
            # tokens to indicate the start and end of the tool call
            # information.
            request.skip_special_tokens = False
        return request

    def get_arguments(self, obj):
        """Return the argument payload of a parsed tool-call object.

        Args:
            obj: Mapping parsed from the tool-call JSON.

        Returns:
            The value stored under ``"parameters"`` when that key is present,
            otherwise the value stored under ``"arguments"``, otherwise
            ``None``.
        """
        if "parameters" in obj:
            return obj.get("parameters")
        elif "arguments" in obj:
            return obj.get("arguments")
        return None

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Turn the newest chunk of generated text into a streaming delta.

        Only ``current_text`` and ``delta_text`` are read here; the remaining
        parameters are accepted but unused by this method.

        Args:
            previous_text: Text generated before this chunk.
            current_text: All text generated so far, including this chunk.
            delta_text: Text added by this chunk.
            previous_token_ids: Token ids generated before this chunk.
            current_token_ids: All token ids generated so far.
            delta_token_ids: Token ids added by this chunk.
            request: The request being served.

        Returns:
            A ``DeltaMessage`` carrying plain content, the tool name, or an
            increment of the serialized arguments; ``None`` when there is
            nothing to emit for this chunk (incomplete JSON, name not yet
            available, or an error while processing the chunk).
        """
        if "<|action_start|>" not in current_text:
            # No tool call has started: pass the text through as content.
            self.position = len(current_text)
            return DeltaMessage(content=delta_text)
        # if the tool call is sent, return an empty delta message
        # to make sure the finish_reason will be sent correctly.
        if self.current_tool_id > 0:
            return DeltaMessage(content="")

        last_pos = self.position
        if "<|action_start|><|plugin|>" not in current_text[last_pos:]:
            # "<|action_start|>" was seen but the full start marker has not
            # arrived yet; wait for more text.
            return None

        new_delta = current_text[last_pos:]
        # partition() splits on the first marker only, so a repeated marker
        # in the output cannot raise on tuple unpacking the way a two-value
        # unpack of split() would.
        text, _, action = new_delta.partition("<|action_start|><|plugin|>")

        if len(text) > 0:
            # Flush content that precedes the tool call before handling it.
            self.position = self.position + len(text)
            return DeltaMessage(content=text)

        action = action.strip()
        action = action.split("<|action_end|>")[0]

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR

        try:
            # tool calls are generated in an object in internlm2
            # parallel tool calls are not supported
            try:
                tool_call_arr: dict = partial_json_parser.loads(action, flags)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug("not enough tokens to parse into JSON yet")
                return None

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            if not self.current_tool_name_sent:
                function_name = tool_call_arr.get("name")
                if function_name:
                    self.current_tool_id = self.current_tool_id + 1
                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                type="function",
                                id=make_tool_call_id(),
                                function=DeltaFunctionCall(
                                    name=function_name
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.current_tool_name_sent = True
                    self.streamed_args_for_tool.append("")
                else:
                    delta = None
            # now we know we're on the same tool call and we're streaming
            # arguments
            else:
                prev_arguments = self.get_arguments(
                    self.prev_tool_call_arr[self.current_tool_id]
                )
                cur_arguments = self.get_arguments(tool_call_arr)

                # no arguments generated yet
                if not cur_arguments and not prev_arguments:
                    delta = None
                # will never happen
                elif not cur_arguments and prev_arguments:
                    logger.error(
                        "INVARIANT - impossible to have arguments reset mid-arguments"
                    )
                    delta = None
                # first time to get parameters
                elif cur_arguments and not prev_arguments:
                    cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)

                    # Emit the serialized arguments up to and including the
                    # first occurrence of the newest chunk. index() raises
                    # ValueError when the chunk is not found; the outer
                    # handler then skips this chunk.
                    arguments_delta = cur_arguments_json[
                        : cur_arguments_json.index(delta_text) + len(delta_text)
                    ]
                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                function=DeltaFunctionCall(
                                    arguments=arguments_delta
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.streamed_args_for_tool[self.current_tool_id] += arguments_delta
                # both prev and cur parameters, send the increase parameters
                elif cur_arguments and prev_arguments:
                    cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
                    prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)

                    # NOTE(review): presumably yields the part of the current
                    # JSON that was added since the previous one - confirm
                    # against vllm.tool_parsers.utils.
                    argument_diff = extract_intermediate_diff(
                        cur_args_json, prev_args_json
                    )

                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                function=DeltaFunctionCall(
                                    arguments=argument_diff
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.streamed_args_for_tool[self.current_tool_id] += argument_diff

            # Remember this chunk's parsed object (with its arguments
            # normalized under the "arguments" key) so the next chunk can be
            # diffed against it, then return whatever delta was built above.
            tool_call_arr["arguments"] = self.get_arguments(tool_call_arr)
            self.prev_tool_call_arr = [tool_call_arr]
            return delta
        except Exception:
            # Best effort: a failure on one chunk is logged and the chunk is
            # dropped instead of aborting the stream.
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction error"
            )
            return None

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract a tool call from a complete (non-streaming) model output.

        Args:
            model_output: The full text generated by the model.
            request: The request being served; ``request.tools`` lists the
                tools the model is allowed to call.

        Returns:
            ``tools_called=True`` with a single ``ToolCall`` when the output
            contains a tool call whose name matches one of the request's
            tools; the text before the start marker is returned as content
            (``None`` if empty). When the request has no tools or the name is
            not among them, ``tools_called=False`` with the text before the
            start marker as content. When no start marker is present,
            ``tools_called=False`` with the full output as content.

        Raises:
            json.JSONDecodeError: If the tool-call payload is not valid JSON.
            KeyError: If the payload has no ``"name"`` entry.
        """
        text = model_output
        tools = request.tools
        if "<|action_start|><|plugin|>" in text:
            # partition() splits on the first marker only, so a repeated
            # marker cannot raise on tuple unpacking.
            text, _, action = text.partition("<|action_start|><|plugin|>")
            action = action.split("<|action_end|>")[0]
            # Drop anything that precedes the opening brace of the JSON object.
            action = action[action.find("{") :]
            action_dict = json.loads(action)
            name = action_dict["name"]
            parameters = json.dumps(
                action_dict.get("parameters", action_dict.get("arguments", {})),
                ensure_ascii=False,
            )

            if not tools or name not in [t.function.name for t in tools]:
                # The model asked for a tool the request did not declare:
                # report no tool call instead of forwarding it.
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=text
                )

            tool_calls = [
                ToolCall(function=FunctionCall(name=name, arguments=parameters))
            ]
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=text if len(text) > 0 else None,
            )

        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=text
        )