# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json

import openai
import pytest

from .utils import (
    MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
    MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
    SEARCH_TOOL,
    WEATHER_TOOL,
    ServerConfig,
)


# test: getting the model to generate parallel tool calls (streaming/not)
# when requested. NOTE that not all models may support this, so some exclusions
# may be added in the future. e.g. llama 3.1 models are not designed to support
# parallel tool calls.
@pytest.mark.asyncio
async def test_parallel_tool_calls(
    client: openai.AsyncOpenAI, server_config: ServerConfig
):
    """Request parallel tool calls and check streaming/non-streaming agree.

    The non-streamed response must contain exactly two calls to the weather
    tool, each with string ``city`` and ``state`` arguments. The same request
    is then streamed; the deltas are reassembled per tool-call index and
    compared call-by-call against the non-streamed result.

    Skipped when ``server_config["supports_parallel"]`` is falsy.
    """
    if not server_config.get("supports_parallel", True):
        pytest.skip(
            "The {} model doesn't support parallel tool calls".format(
                server_config["model"]
            )
        )

    models = await client.models.list()
    model_name: str = models.data[0].id
    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
        temperature=0,
        max_completion_tokens=200,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
    )

    choice = chat_completion.choices[0]
    stop_reason = choice.finish_reason
    non_streamed_tool_calls = choice.message.tool_calls

    # make sure 2 tool calls are present
    assert choice.message.role == "assistant"
    assert non_streamed_tool_calls is not None
    assert len(non_streamed_tool_calls) == 2

    for tool_call in non_streamed_tool_calls:
        # make sure the tool includes a function and ID
        assert tool_call.type == "function"
        assert tool_call.function is not None
        assert isinstance(tool_call.id, str)
        assert len(tool_call.id) >= 9

        # make sure the weather tool was called correctly
        assert tool_call.function.name == WEATHER_TOOL["function"]["name"]
        assert isinstance(tool_call.function.arguments, str)

        parsed_arguments = json.loads(tool_call.function.arguments)
        assert isinstance(parsed_arguments, dict)
        assert isinstance(parsed_arguments.get("city"), str)
        assert isinstance(parsed_arguments.get("state"), str)

    assert stop_reason == "tool_calls"

    # make the same request, streaming
    stream = await client.chat.completions.create(
        model=model_name,
        messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
        temperature=0,
        max_completion_tokens=200,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        stream=True,
    )

    role_name: str | None = None
    finish_reason_count: int = 0

    tool_call_names: list[str] = []
    # accumulated argument fragments, one entry per streamed tool-call index
    tool_call_args: list[str] = []
    tool_call_idx: int = -1
    tool_call_id_count: int = 0

    async for chunk in stream:
        delta = chunk.choices[0].delta

        # if there's a finish reason make sure it's tools
        if chunk.choices[0].finish_reason:
            finish_reason_count += 1
            assert chunk.choices[0].finish_reason == "tool_calls"

        # if a role is streamed, it must be "assistant" and must agree with
        # any role already streamed -- check the streamed value itself, not a
        # locally assigned constant
        if delta.role:
            assert delta.role == "assistant"
            assert role_name is None or role_name == delta.role
            role_name = delta.role

        # if a tool call is streamed make sure there's exactly one
        # (based on the request parameters)
        streamed_tool_calls = delta.tool_calls

        if streamed_tool_calls:
            # make sure only one diff is present - correct even for parallel
            assert len(streamed_tool_calls) == 1
            tool_call = streamed_tool_calls[0]

            # if a new tool is being called, set up empty arguments
            if tool_call.index != tool_call_idx:
                tool_call_idx = tool_call.index
                tool_call_args.append("")

            # count streamed IDs; each tool call should carry its ID once
            # (checked against the number of tool calls after the loop)
            if tool_call.id:
                tool_call_id_count += 1
                assert isinstance(tool_call.id, str) and (len(tool_call.id) >= 9)

            # if parts of the function start being streamed
            if tool_call.function:
                # if the function name is defined, record it. it should be
                # streamed IN ENTIRETY, exactly one time.
                if tool_call.function.name:
                    assert isinstance(tool_call.function.name, str)
                    tool_call_names.append(tool_call.function.name)

                if tool_call.function.arguments:
                    # make sure they're a string and then add them to the list
                    assert isinstance(tool_call.function.arguments, str)
                    tool_call_args[tool_call.index] += tool_call.function.arguments

    assert finish_reason_count == 1
    assert role_name == "assistant"
    # exactly one streamed ID per tool call (mirrors the single-call check in
    # test_parallel_tool_calls_false)
    assert tool_call_id_count == len(non_streamed_tool_calls)

    assert len(non_streamed_tool_calls) == len(tool_call_names) == len(tool_call_args)

    # lengths are equal (asserted above), so zip pairs every call up
    for non_streamed_call, streamed_name, streamed_arg_str in zip(
        non_streamed_tool_calls, tool_call_names, tool_call_args
    ):
        assert non_streamed_call.function.name == streamed_name
        streamed_args = json.loads(streamed_arg_str)
        non_streamed_args = json.loads(non_streamed_call.function.arguments)
        assert streamed_args == non_streamed_args


# test: providing parallel tool calls back to the model to get a response
# (streaming/not)
@pytest.mark.asyncio
async def test_parallel_tool_calls_with_results(
    client: openai.AsyncOpenAI, server_config: ServerConfig
):
    """Feed parallel tool results back and check the model answers in text.

    Both the non-streamed and streamed responses must be plain assistant
    content with no further tool calls. The content must mention the
    temperatures from the tool results. The streamed content must join to
    exactly the non-streamed content, and the streamed finish reason must
    match the non-streamed one.

    Skipped when ``server_config["supports_parallel"]`` is falsy.
    """
    if not server_config.get("supports_parallel", True):
        pytest.skip(
            "The {} model doesn't support parallel tool calls".format(
                server_config["model"]
            )
        )

    models = await client.models.list()
    model_name: str = models.data[0].id
    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
        temperature=0,
        max_completion_tokens=200,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
    )

    choice = chat_completion.choices[0]

    assert choice.finish_reason != "tool_calls"  # "stop" or "length"
    assert choice.message.role == "assistant"
    # None or an empty list are both acceptable "no tool calls" values
    assert not choice.message.tool_calls
    assert choice.message.content is not None
    assert "98" in choice.message.content  # Dallas temp in tool response
    assert "78" in choice.message.content  # Orlando temp in tool response

    stream = await client.chat.completions.create(
        messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
        temperature=0,
        max_completion_tokens=200,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        stream=True,
    )

    chunks: list[str] = []
    finish_reason_count = 0
    role_sent: bool = False

    async for chunk in stream:
        delta = chunk.choices[0].delta

        # the role must be streamed exactly once, as "assistant"
        if delta.role:
            assert not role_sent
            assert delta.role == "assistant"
            role_sent = True

        if delta.content:
            chunks.append(delta.content)

        # the streamed finish reason must match the non-streamed one
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
            assert chunk.choices[0].finish_reason == choice.finish_reason

        # no tool calls are expected once the tool results are provided
        assert not delta.tool_calls

    assert role_sent
    assert finish_reason_count == 1
    assert chunks
    assert "".join(chunks) == choice.message.content


@pytest.mark.asyncio
async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI):
    """
    Ensure only one tool call is returned when parallel_tool_calls is False.

    Uses the same prompt as the parallel test. Checks both the non-streamed
    response and the streamed response, where exactly one tool-call ID may
    be streamed in total.
    """

    models = await client.models.list()
    model_name: str = models.data[0].id
    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
        temperature=0,
        max_completion_tokens=200,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        parallel_tool_calls=False,
    )

    stop_reason = chat_completion.choices[0].finish_reason
    non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls

    # make sure exactly 1 tool call is present; check for None first so a
    # missing tool call fails as an assertion rather than a TypeError
    assert non_streamed_tool_calls is not None
    assert len(non_streamed_tool_calls) == 1
    assert stop_reason == "tool_calls"

    # make the same request, streaming
    stream = await client.chat.completions.create(
        model=model_name,
        messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
        temperature=0,
        max_completion_tokens=200,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        parallel_tool_calls=False,
        stream=True,
    )

    finish_reason_count: int = 0
    tool_call_id_count: int = 0

    async for chunk in stream:
        # if there's a finish reason make sure it's tools
        if chunk.choices[0].finish_reason:
            finish_reason_count += 1
            assert chunk.choices[0].finish_reason == "tool_calls"

        streamed_tool_calls = chunk.choices[0].delta.tool_calls
        if streamed_tool_calls:
            # count IDs across every tool call in the delta, not just the
            # first, so a chunk carrying several calls cannot hide extras
            tool_call_id_count += sum(
                1 for tool_call in streamed_tool_calls if tool_call.id
            )

    # make sure only 1 streaming tool call is present
    assert tool_call_id_count == 1
    assert finish_reason_count == 1