# openai_tool_parser.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
from typing import TYPE_CHECKING

from vllm.entrypoints.openai.parser.harmony_utils import parse_output_into_messages
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaMessage,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import (
    ToolParser,
)

# TokenizerLike is only needed for annotations; at runtime it is aliased to
# ``object`` so that ``vllm.tokenizers`` is not imported eagerly.
if TYPE_CHECKING:
    from vllm.tokenizers import TokenizerLike
else:
    TokenizerLike = object

logger = init_logger(__name__)


class OpenAIToolParser(ToolParser):
    """Tool parser for gpt-oss models that emit Harmony-formatted output.

    Parsing is done on the generated token IDs rather than on decoded text:
    the IDs are fed to ``parse_output_into_messages`` and each resulting
    message is classified by its recipient and channel.
    """

    # Tool calls are messages whose recipient is "functions.<tool name>".
    _FUNCTION_RECIPIENT_PREFIX = "functions."

    def __init__(self, tokenizer: "TokenizerLike"):
        super().__init__(tokenizer)

    @staticmethod
    def _normalize_tool_arguments(text: str, content_type: str | None) -> str:
        """Return tool-call argument text, re-serialized when it is JSON.

        Text whose content type is missing or mentions "json" is round-tripped
        through ``json.loads``/``json.dumps`` to validate it and strip extra
        newlines or other odd formatting. Invalid JSON is logged and returned
        unchanged, as is text with any other content type.
        """
        # If no content-type is given assume JSON, as that's the
        # most common case with gpt-oss models.
        if content_type and "json" not in content_type:
            return text
        try:
            return json.dumps(json.loads(text))
        except json.JSONDecodeError:
            logger.exception("Error decoding JSON tool call from response.")
            return text

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
        token_ids: Sequence[int] | None = None,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls and user-visible content from a complete response.

        Args:
            model_output: Decoded model output. Not used; extraction works on
                ``token_ids``.
            request: The originating chat completion request (not used here).
            token_ids: Generated token IDs to parse. Required.

        Returns:
            ``ExtractedToolCallInformation`` containing one ``ToolCall`` per
            message addressed to a ``functions.*`` recipient. Its content is
            the "final" channel text, falling back to recipient-less
            "commentary" text when there is no (non-empty) final text.

        Raises:
            NotImplementedError: If ``token_ids`` is None.
        """
        if token_ids is None:
            raise NotImplementedError(
                "OpenAIToolParser requires token IDs and does not support text-based extraction."  # noqa: E501
            )

        prefix = self._FUNCTION_RECIPIENT_PREFIX
        parser = parse_output_into_messages(token_ids)
        tool_calls: list[ToolCall] = []
        final_content = None
        commentary_content = None

        for msg in parser.messages:
            # Messages with no content carry nothing to extract; otherwise
            # only the first content item is used.
            if not msg.content:
                continue
            msg_text = msg.content[0].text
            recipient = msg.recipient

            if recipient and recipient.startswith(prefix):
                tool_calls.append(
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            # removeprefix strips only the leading marker, so
                            # a tool name that itself contains "functions."
                            # is kept intact (split(...)[1] would cut it).
                            name=recipient.removeprefix(prefix),
                            arguments=self._normalize_tool_arguments(
                                msg_text, msg.content_type
                            ),
                        ),
                    )
                )
            elif msg.channel == "final":
                final_content = msg_text
            elif msg.channel == "commentary" and not recipient:
                commentary_content = msg_text

        # Extract partial content from the parser state if the generation was
        # truncated before the current message was closed.
        if parser.current_content:
            if parser.current_channel == "final":
                final_content = parser.current_content
            elif (
                parser.current_channel == "commentary" and not parser.current_recipient
            ):
                commentary_content = parser.current_content

        return ExtractedToolCallInformation(
            tools_called=len(tool_calls) > 0,
            tool_calls=tool_calls,
            # prefer final content over commentary content if both are present
            # commentary content is tool call preambles meant to be shown to the user
            content=final_content or commentary_content,
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Not supported: streaming tool calls are parsed in serving_chat.py.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Not being used, manual parsing in serving_chat.py"  # noqa: E501
        )