utils.py 6.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable
5

6
7
8
9
10
11
12
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaMessage,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
13
from vllm.entrypoints.openai.tool_parsers import ToolParser
14
from vllm.tokenizers import TokenizerLike
15
16
17
18


class StreamingToolReconstructor:
    """Rebuild complete tool calls and plain content from streamed deltas.

    Each ``DeltaMessage`` passed to ``append_delta`` is validated against the
    streaming conventions asserted below and folded into two accumulators:

    * ``tool_calls``: one ``ToolCall`` per tool call, with its argument
      string concatenated across deltas.
    * ``other_content``: the concatenation of every non-``None``
      ``delta.content``.

    Args:
        assert_one_tool_per_delta: When true (the default), assert that a
            single delta never carries more than one tool call.
    """

    def __init__(self, assert_one_tool_per_delta: bool = True):
        self.tool_calls: list[ToolCall] = []
        self.other_content: str = ""
        self._assert_one_tool_per_delta = assert_one_tool_per_delta

    def append_delta(self, delta: DeltaMessage) -> None:
        """Validate one streamed delta and merge it into the accumulators.

        Args:
            delta: The streamed update to merge. It must carry content,
                tool calls, or both.

        Raises:
            AssertionError: If the delta has neither content nor tool calls,
                carries more than one tool call while
                ``assert_one_tool_per_delta`` is set, has a non-function tool
                type, repeats a function name or id for an already-started
                call, omits the id or name on a call's first appearance, or
                uses an index other than the expected one.
        """
        if delta.content is not None:
            self.other_content += delta.content
        else:
            assert delta.tool_calls, (
                "Streaming results should have either content or tool calls (or both)"
            )
        if self._assert_one_tool_per_delta:
            # Note: This isn't strictly required by the API and may not be
            # possible to adhere to depending on the token space and number of
            # tokens per streamed response from the model, but it is required
            # by tool_use tests, so we enforce it here by default also.
            assert len(delta.tool_calls) < 2, (
                "Streaming should include only one tool call per update."
            )
        for call_delta in delta.tool_calls:
            assert call_delta.type is None or call_delta.type == "function", (
                "Streaming tool calls should only emit function calls. Got "
                f"{call_delta.type}"
            )
            # An index inside the list refers to a call that was already
            # started; anything else is treated as a first appearance.
            current_tool_call = (
                self.tool_calls[call_delta.index]
                if call_delta.index < len(self.tool_calls)
                else None
            )
            if current_tool_call:
                assert not call_delta.function.name, (
                    "Streaming tool calls should emit the full function name "
                    f"exactly once. Got {call_delta.function.name}"
                )
                assert not call_delta.id, (
                    "Streaming tool calls must emit function id only once. Got "
                    f"{call_delta.id}"
                )
                # Only the most recently started call may receive more text.
                assert call_delta.index == len(self.tool_calls) - 1, (
                    f"Incorrect index for tool delta. Got {call_delta.index}, "
                    f"expected {len(self.tool_calls) - 1}"
                )
                # Treat a missing (None) arguments value as empty text, the
                # same way the first-appearance branch below does, instead of
                # failing with a TypeError on ``str += None``.
                current_tool_call.function.arguments += (
                    call_delta.function.arguments or ""
                )
            else:
                assert call_delta.id is not None, (
                    "Streaming tool calls must have an id on first appearance"
                )
                assert call_delta.function.name is not None, (
                    "Streaming tool calls must have a function name on first appearance"
                )
                # New calls must be appended in order, with no gaps.
                assert call_delta.index == len(self.tool_calls), (
                    f"Incorrect index for tool delta. Got {call_delta.index}, "
                    f"expected {len(self.tool_calls)}"
                )
                self.tool_calls.append(
                    ToolCall(
                        id=call_delta.id,
                        function=FunctionCall(
                            name=call_delta.function.name,
                            arguments=call_delta.function.arguments or "",
                        ),
                    )
                )
82
83
84
85
86


def run_tool_extraction(
    tool_parser: ToolParser,
    model_output: str,
    request: ChatCompletionRequest | None = None,
    streaming: bool = False,
    assert_one_tool_per_delta: bool = True,
) -> tuple[str | None, list[ToolCall]]:
    """Extract tool calls from ``model_output`` in one shot or via streaming.

    Args:
        tool_parser: Parser under test.
        model_output: Raw model text to parse.
        request: Optional request forwarded to the parser; the helpers
            called here substitute a placeholder when it is omitted.
        streaming: Selects the streaming path when true, otherwise the
            non-streaming path.
        assert_one_tool_per_delta: Forwarded to the streaming path only.

    Returns:
        A ``(content, tool_calls)`` pair. On the streaming path, empty
        accumulated content is reported as ``None``.
    """
    if not streaming:
        result = run_tool_extraction_nonstreaming(tool_parser, model_output, request)
        # The parser's "tools_called" flag must agree with what it returned.
        assert result.tools_called == bool(result.tool_calls)
        return result.content, result.tool_calls

    rebuilt = run_tool_extraction_streaming(
        tool_parser,
        model_output,
        request,
        assert_one_tool_per_delta=assert_one_tool_per_delta,
    )
    text = rebuilt.other_content
    return (text if text else None), rebuilt.tool_calls


def run_tool_extraction_nonstreaming(
    tool_parser: ToolParser,
    model_output: str,
    request: ChatCompletionRequest | None = None,
) -> ExtractedToolCallInformation:
    """Run the parser's one-shot extraction over the full model output.

    Args:
        tool_parser: Parser under test.
        model_output: Complete model text to parse.
        request: Request passed to the parser. When omitted, a request with
            no messages and the model name ``"test-model"`` is used.

    Returns:
        Whatever ``tool_parser.extract_tool_calls`` returns.
    """
    if not request:
        request = ChatCompletionRequest(messages=[], model="test-model")
    return tool_parser.extract_tool_calls(model_output, request)


114
def split_string_into_token_deltas(tokenizer: TokenizerLike, text: str) -> list[str]:
    """Split ``text`` into per-token string deltas using ``tokenizer``.

    The text is encoded without special tokens. For each token, the prefix of
    token ids ending at that token is decoded, and the delta is whatever that
    decode adds beyond the length of the previous prefix's decode.

    Args:
        tokenizer: Tokenizer providing ``encode`` and ``decode``.
        text: The string to split.

    Returns:
        One string per encoded token, in order.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    pieces = []
    consumed = 0  # length of the text decoded from the previous prefix
    for end in range(1, len(token_ids) + 1):
        prefix_text = tokenizer.decode(token_ids[:end])
        pieces.append(prefix_text[consumed:])
        consumed = len(prefix_text)
    return pieces


129
130
131
def run_tool_extraction_streaming(
    tool_parser: ToolParser,
    model_deltas: Iterable[str],
    request: ChatCompletionRequest | None = None,
    assert_one_tool_per_delta: bool = True,
) -> StreamingToolReconstructor:
    """Feed text deltas through the parser's streaming API and collect results.

    Args:
        tool_parser: Parser under test.
        model_deltas: The text chunks to stream. A plain string is first
            split into per-token deltas with the parser's tokenizer.
        request: Request passed to the parser. When omitted, a request with
            no messages and the model name ``"test-model"`` is used.
        assert_one_tool_per_delta: Forwarded to the reconstructor.

    Returns:
        The reconstructor holding everything accumulated from the stream.
    """
    if isinstance(model_deltas, str):
        model_deltas = split_string_into_token_deltas(
            tool_parser.model_tokenizer, model_deltas
        )

    if not request:
        request = ChatCompletionRequest(messages=[], model="test-model")
    reconstructor = StreamingToolReconstructor(
        assert_one_tool_per_delta=assert_one_tool_per_delta
    )
    text_so_far = ""
    ids_so_far: list[int] = []
    for chunk in model_deltas:
        # Map the chunk's tokens to ids, skipping tokens absent from the
        # parser's vocabulary.
        chunk_ids = []
        for piece in tool_parser.model_tokenizer.tokenize(chunk):
            if piece in tool_parser.vocab:
                chunk_ids.append(tool_parser.vocab.get(piece))
        text_after = text_so_far + chunk
        ids_after = ids_so_far + chunk_ids
        message = tool_parser.extract_tool_calls_streaming(
            text_so_far,
            text_after,
            chunk,
            ids_so_far,
            ids_after,
            chunk_ids,
            request,
        )
        # The parser may return None for a chunk; only real messages are merged.
        if message is not None:
            reconstructor.append_delta(message)
        text_so_far, ids_so_far = text_after, ids_after
    return reconstructor