# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json

import openai
import pytest

from .utils import (
    MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
    MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
    SEARCH_TOOL,
    SEED,
    WEATHER_TOOL,
    ServerConfig,
)


# test: getting the model to generate parallel tool calls (streaming/not)
# when requested. NOTE that not all models may support this, so some exclusions
# may be added in the future. e.g. llama 3.1 models are not designed to support
# parallel tool calls.
@pytest.mark.asyncio
async def test_parallel_tool_calls(
    client: openai.AsyncOpenAI, server_config: ServerConfig
):
    """Request parallel tool calls, non-streaming then streaming, and compare.

    The non-streaming response must contain exactly two well-formed calls to
    the weather tool and finish with ``"tool_calls"``. The same request is then
    streamed; the tool calls reassembled from the deltas must match the
    non-streamed ones by name and by parsed arguments.

    Skipped when ``server_config`` sets ``supports_parallel`` to a falsy value.
    """
    if not server_config.get("supports_parallel", True):
        pytest.skip(
            f"The {server_config['model']} model doesn't support parallel tool calls"
        )

    models = await client.models.list()
    model_name: str = models.data[0].id
    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
        temperature=0,
        max_completion_tokens=200,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        seed=SEED,
    )

    choice = chat_completion.choices[0]
    stop_reason = choice.finish_reason
    non_streamed_tool_calls = choice.message.tool_calls

    # make sure 2 tool calls are present
    assert choice.message.role == "assistant"
    assert non_streamed_tool_calls is not None
    assert len(non_streamed_tool_calls) == 2

    for tool_call in non_streamed_tool_calls:
        # make sure the tool includes a function and ID
        assert tool_call.type == "function"
        assert tool_call.function is not None
        assert isinstance(tool_call.id, str)
        assert len(tool_call.id) >= 9

        # make sure the weather tool was called correctly
        assert tool_call.function.name == WEATHER_TOOL["function"]["name"]
        assert isinstance(tool_call.function.arguments, str)

        parsed_arguments = json.loads(tool_call.function.arguments)
        assert isinstance(parsed_arguments, dict)
        assert isinstance(parsed_arguments.get("city"), str)
        assert isinstance(parsed_arguments.get("state"), str)

    assert stop_reason == "tool_calls"

    # make the same request, streaming
    stream = await client.chat.completions.create(
        model=model_name,
        messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
        temperature=0,
        max_completion_tokens=200,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        seed=SEED,
        stream=True,
    )

    role_name: str | None = None
    finish_reason_count: int = 0

    tool_call_names: list[str] = []
    tool_call_args: list[str] = []
    tool_call_idx: int = -1
    tool_call_id_count: int = 0

    async for chunk in stream:
        stream_choice = chunk.choices[0]

        # if there's a finish reason make sure it's tools
        if stream_choice.finish_reason:
            finish_reason_count += 1
            assert stream_choice.finish_reason == "tool_calls"

        # if a role is being streamed make sure it is "assistant" and that it
        # wasn't already set to something else
        if stream_choice.delta.role:
            assert stream_choice.delta.role == "assistant"
            assert not role_name or role_name == "assistant"
            role_name = stream_choice.delta.role

        # if a tool call is streamed make sure there's exactly one
        # (based on the request parameters)
        streamed_tool_calls = stream_choice.delta.tool_calls

        if streamed_tool_calls:
            # make sure only one diff is present - correct even for parallel
            assert len(streamed_tool_calls) == 1
            tool_call = streamed_tool_calls[0]

            # if a new tool is being called, set up empty arguments
            if tool_call.index != tool_call_idx:
                tool_call_idx = tool_call.index
                tool_call_args.append("")

            # if a tool call ID is streamed, count it and check its shape; the
            # total is compared against the number of tool calls after the loop
            if tool_call.id:
                tool_call_id_count += 1
                assert isinstance(tool_call.id, str) and (len(tool_call.id) >= 9)

            # if parts of the function start being streamed
            if tool_call.function:
                # if the function name is defined, set it. it should be streamed
                # IN ENTIRETY, exactly one time.
                if tool_call.function.name:
                    assert isinstance(tool_call.function.name, str)
                    tool_call_names.append(tool_call.function.name)

                if tool_call.function.arguments:
                    # make sure they're a string and then add them to the list
                    assert isinstance(tool_call.function.arguments, str)

                    tool_call_args[tool_call.index] += tool_call.function.arguments

    assert finish_reason_count == 1
    assert role_name == "assistant"
    # the stream must carry as many tool call IDs as there are tool calls
    assert tool_call_id_count == len(non_streamed_tool_calls)

    assert len(non_streamed_tool_calls) == len(tool_call_names) == len(tool_call_args)

    # lengths were asserted equal above, so zip covers every tool call
    for expected_call, streamed_name, streamed_raw_args in zip(
        non_streamed_tool_calls, tool_call_names, tool_call_args
    ):
        assert expected_call.function.name == streamed_name
        streamed_args = json.loads(streamed_raw_args)
        non_streamed_args = json.loads(expected_call.function.arguments)
        assert streamed_args == non_streamed_args


# test: providing parallel tool calls back to the model to get a response
# (streaming/not)
@pytest.mark.asyncio
async def test_parallel_tool_calls_with_results(
    client: openai.AsyncOpenAI, server_config: ServerConfig
):
    """Send MESSAGES_WITH_PARALLEL_TOOL_RESPONSE and expect a text answer.

    The non-streaming response must be an assistant message with content, no
    tool calls, and a finish reason other than ``"tool_calls"``. The same
    request is then streamed; the concatenated content deltas must equal the
    non-streamed content and the finish reason must match.

    Skipped when ``server_config`` sets ``supports_parallel`` to a falsy value.
    """
    if not server_config.get("supports_parallel", True):
        pytest.skip(
            f"The {server_config['model']} model doesn't support parallel tool calls"
        )

    models = await client.models.list()
    model_name: str = models.data[0].id
    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
        temperature=0,
        max_completion_tokens=200,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        seed=SEED,
    )

    choice = chat_completion.choices[0]

    assert choice.finish_reason != "tool_calls"  # "stop" or "length"
    assert choice.message.role == "assistant"
    # tool_calls may be None or an empty list; either means "no tool calls"
    assert not choice.message.tool_calls
    assert choice.message.content is not None
    assert "98" in choice.message.content  # Dallas temp in tool response
    assert "78" in choice.message.content  # Orlando temp in tool response

    # make the same request, streaming
    stream = await client.chat.completions.create(
        messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
        temperature=0,
        max_completion_tokens=200,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        seed=SEED,
        stream=True,
    )

    chunks: list[str] = []
    finish_reason_count = 0
    role_sent: bool = False

    async for chunk in stream:
        stream_choice = chunk.choices[0]
        delta = stream_choice.delta

        # the role must be streamed at most once and must be "assistant"
        if delta.role:
            assert not role_sent
            assert delta.role == "assistant"
            role_sent = True

        if delta.content:
            chunks.append(delta.content)

        # the streamed finish reason must agree with the non-streamed one
        if stream_choice.finish_reason is not None:
            finish_reason_count += 1
            assert stream_choice.finish_reason == choice.finish_reason

        # no tool calls may appear anywhere in the stream
        assert not delta.tool_calls

    assert role_sent
    assert finish_reason_count == 1
    assert chunks
    assert "".join(chunks) == choice.message.content


@pytest.mark.asyncio
async def test_parallel_tool_calls_false(client: openai.AsyncOpenAI):
    """
    Ensure only one tool call is returned when parallel_tool_calls is False.

    The request is made twice, non-streaming and streaming. Both must finish
    with ``"tool_calls"``; the non-streaming response must hold exactly one
    tool call and the stream must carry exactly one tool call ID.
    """

    models = await client.models.list()
    model_name: str = models.data[0].id
    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
        temperature=0,
        max_completion_tokens=200,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        seed=SEED,
        parallel_tool_calls=False,
    )

    stop_reason = chat_completion.choices[0].finish_reason
    non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls

    # make sure only 1 tool call is present; check for None first so a missing
    # tool call fails as an assertion rather than a TypeError from len()
    assert non_streamed_tool_calls is not None
    assert len(non_streamed_tool_calls) == 1
    assert stop_reason == "tool_calls"

    # make the same request, streaming
    stream = await client.chat.completions.create(
        model=model_name,
        messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
        temperature=0,
        max_completion_tokens=200,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        seed=SEED,
        parallel_tool_calls=False,
        stream=True,
    )

    finish_reason_count: int = 0
    tool_call_id_count: int = 0

    async for chunk in stream:
        # if there's a finish reason make sure it's tools
        if chunk.choices[0].finish_reason:
            finish_reason_count += 1
            assert chunk.choices[0].finish_reason == "tool_calls"

        # count every tool call ID in the delta, not just the first entry, so
        # a second tool call delivered in the same chunk is not missed
        for tool_call in chunk.choices[0].delta.tool_calls or []:
            if tool_call.id:
                tool_call_id_count += 1

    # make sure only 1 streaming tool call is present
    assert tool_call_id_count == 1
    assert finish_reason_count == 1