test_tool_calls.py 7.11 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
import json

import openai
import pytest

9
10
11
12
from .utils import (
    MESSAGES_ASKING_FOR_TOOLS,
    MESSAGES_WITH_TOOL_RESPONSE,
    SEARCH_TOOL,
13
    SEED,
14
15
    WEATHER_TOOL,
)
16
17
18
19
20
21
22
23
24
25
26


# test: request a chat completion that should return tool calls, so we know they
# are parsable
@pytest.mark.asyncio
async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
    """Check tool-call parsing for the same request, non-streaming and streaming.

    The non-streaming response must contain exactly one ``WEATHER_TOOL`` call
    whose JSON arguments decode to ``{"city": "Dallas", ...  "state": "TX"}``
    with finish reason ``"tool_calls"``. The streamed response must deliver a
    single finish reason, an ``"assistant"`` role, one tool-call id and one
    function name, and its concatenated argument fragments must decode to the
    same dict as the non-streaming arguments.
    """
    models = await client.models.list()
    model_name: str = models.data[0].id
    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_ASKING_FOR_TOOLS,
        temperature=0,
        max_completion_tokens=100,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        seed=SEED,
    )

    choice = chat_completion.choices[0]
    stop_reason = choice.finish_reason
    tool_calls = choice.message.tool_calls

    # make sure a tool call is present
    assert choice.message.role == "assistant"
    assert tool_calls is not None
    assert len(tool_calls) == 1
    assert tool_calls[0].type == "function"
    assert tool_calls[0].function is not None
    assert isinstance(tool_calls[0].id, str)
    assert len(tool_calls[0].id) >= 9

    # make sure the weather tool was called (classic example) with arguments
    assert tool_calls[0].function.name == WEATHER_TOOL["function"]["name"]
    assert tool_calls[0].function.arguments is not None
    assert isinstance(tool_calls[0].function.arguments, str)

    # make sure the arguments parse properly
    parsed_arguments = json.loads(tool_calls[0].function.arguments)
    assert isinstance(parsed_arguments, dict)
    assert isinstance(parsed_arguments.get("city"), str)
    assert isinstance(parsed_arguments.get("state"), str)
    assert parsed_arguments.get("city") == "Dallas"
    assert parsed_arguments.get("state") == "TX"

    assert stop_reason == "tool_calls"

    # state accumulated while consuming the streamed version of the request
    function_name: str | None = None
    function_args_str: str = ""
    tool_call_id: str | None = None
    role_name: str | None = None
    finish_reason_count: int = 0

    # make the same request, streaming
    stream = await client.chat.completions.create(
        model=model_name,
        messages=MESSAGES_ASKING_FOR_TOOLS,
        temperature=0,
        max_completion_tokens=100,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        seed=SEED,
        stream=True,
    )

    async for chunk in stream:
        stream_choice = chunk.choices[0]
        assert stream_choice.index == 0
        delta = stream_choice.delta

        if stream_choice.finish_reason:
            finish_reason_count += 1
            assert stream_choice.finish_reason == "tool_calls"

        # if a role is being streamed, it must be "assistant" and must not
        # contradict a role streamed earlier (the original check never looked
        # at delta.role itself, so it could not fail)
        if delta.role:
            assert delta.role == "assistant"
            assert role_name is None or role_name == delta.role
            role_name = delta.role

        # if a tool call is streamed make sure there's exactly one
        # (based on the request parameters)
        streamed_tool_calls = delta.tool_calls

        if streamed_tool_calls:
            assert len(streamed_tool_calls) == 1
            tool_call = streamed_tool_calls[0]

            # if a tool call ID is streamed, make sure one hasn't been already
            if tool_call.id:
                assert not tool_call_id
                tool_call_id = tool_call.id

            # if parts of the function start being streamed
            if tool_call.function:
                # if the function name is defined, set it. it should be streamed
                # IN ENTIRETY, exactly one time.
                if tool_call.function.name:
                    assert function_name is None
                    assert isinstance(tool_call.function.name, str)
                    function_name = tool_call.function.name
                if tool_call.function.arguments:
                    assert isinstance(tool_call.function.arguments, str)
                    function_args_str += tool_call.function.arguments

    assert finish_reason_count == 1
    assert role_name == "assistant"
    assert isinstance(tool_call_id, str) and (len(tool_call_id) >= 9)

    # validate the name and arguments
    assert function_name == WEATHER_TOOL["function"]["name"]

    # validate arguments
    streamed_args = json.loads(function_args_str)
    assert isinstance(streamed_args, dict)
    assert isinstance(streamed_args.get("city"), str)
    assert isinstance(streamed_args.get("state"), str)
    assert streamed_args.get("city") == "Dallas"
    assert streamed_args.get("state") == "TX"

    # make sure everything matches non-streaming except for ID
    assert function_name == tool_calls[0].function.name
    assert choice.message.role == role_name

    # compare streamed with non-streamed args dict-wise, not string-wise
    # because character-to-character comparison might not work e.g. the tool
    # call parser adding extra spaces or something like that. we care about the
    # dicts matching not byte-wise match
    assert parsed_arguments == streamed_args


# test: providing tools and results back to model to get a non-tool response
# (streaming/not)
@pytest.mark.asyncio
async def test_tool_call_with_results(client: openai.AsyncOpenAI):
    """Check that a conversation already containing a tool result gets text back.

    The non-streaming response must not finish with ``"tool_calls"``, must
    carry no tool calls, and must mention the temperature ("98") from the tool
    result. The streamed response must send the ``"assistant"`` role exactly
    once and a single finish reason equal to the non-streaming one. It must
    contain no tool calls, and its concatenated content must equal the
    non-streaming content.
    """
    models = await client.models.list()
    model_name: str = models.data[0].id
    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_WITH_TOOL_RESPONSE,
        temperature=0,
        max_completion_tokens=100,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        seed=SEED,
    )

    choice = chat_completion.choices[0]

    assert choice.finish_reason != "tool_calls"  # "stop" or "length"
    assert choice.message.role == "assistant"
    assert not choice.message.tool_calls  # None or empty list
    assert choice.message.content is not None
    assert "98" in choice.message.content  # the temperature from the response

    stream = await client.chat.completions.create(
        messages=MESSAGES_WITH_TOOL_RESPONSE,
        temperature=0,
        max_completion_tokens=100,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        seed=SEED,
        stream=True,
    )

    chunks: list[str] = []
    finish_reason_count = 0
    role_sent: bool = False

    async for chunk in stream:
        stream_choice = chunk.choices[0]
        delta = stream_choice.delta

        # the role must be streamed exactly once, as "assistant"
        if delta.role:
            assert not role_sent
            assert delta.role == "assistant"
            role_sent = True

        if delta.content:
            chunks.append(delta.content)

        if stream_choice.finish_reason is not None:
            finish_reason_count += 1
            assert stream_choice.finish_reason == choice.finish_reason

        # no tool calls may appear anywhere in the stream (None or empty list)
        assert not delta.tool_calls

    assert role_sent
    assert finish_reason_count == 1
    assert chunks
    assert "".join(chunks) == choice.message.content