# test_tools_llama.py
import pytest


@pytest.fixture(scope="module")
def flash_llama_grammar_tools_handle(launcher):
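    """Launch TinyLlama-1.1B-Chat on two shards with grammar support enabled,
    which the tool-calling tests below depend on."""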
    with launcher(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_shard=2, disable_grammar_support=False
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_llama_grammar_tools(flash_llama_grammar_tools_handle):
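    """Wait for the server to report healthy (300 s timeout) and return its client."""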
    await flash_llama_grammar_tools_handle.health(300)
    return flash_llama_grammar_tools_handle.client


# Tools shared by the tests below: two weather lookup functions.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                },
                "required": ["location", "format"],
                "additionalProperties": False,
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_n_day_weather_forecast",
            "description": "Get an N-day weather forecast",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                    "num_days": {
                        "type": "integer",
                        "description": "The number of days to forecast",
                    },
                },
                "required": ["location", "format", "num_days"],
                "additionalProperties": False,
            },
        },
    },
]


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
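    """Greedy tool call with the default tool choice: the model should answer the
    weather question with a single get_current_weather call and no text content."""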
    response = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        temperature=0.0,
        messages=[
            {
                "role": "system",
                "content": "Youre a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Brooklyn, New York?",
            },
        ],
    )
    assert response.choices[0].message.content is None
    assert response.choices[0].message.tool_calls == [
        {
            "id": "0",
            "type": "function",
            "function": {
                "description": None,
                "name": "get_current_weather",
                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
            },
        }
    ]
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_auto(
    flash_llama_grammar_tools, response_snapshot
):
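    """Same request with tool_choice="auto" passed explicitly; the expected tool
    call matches the default-choice test above."""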
    response = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        temperature=0.0,
        tool_choice="auto",
        messages=[
            {
                "role": "system",
                "content": "Youre a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Brooklyn, New York?",
            },
        ],
    )
    assert response.choices[0].message.content is None
    assert response.choices[0].message.tool_calls == [
        {
            "id": "0",
            "type": "function",
            "function": {
                "description": None,
                "name": "get_current_weather",
                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
            },
        }
    ]

    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_choice(
    flash_llama_grammar_tools, response_snapshot
):
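    """Force a specific tool by passing its name ("get_current_weather") as tool_choice."""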
    response = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        temperature=0.0,
        tool_choice="get_current_weather",
        messages=[
            {
                "role": "system",
                "content": "Youre a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Brooklyn, New York?",
            },
        ],
    )
    assert response.choices[0].message.content is None
    assert response.choices[0].message.tool_calls == [
        {
            "id": "0",
            "type": "function",
            "function": {
                "description": None,
                "name": "get_current_weather",
                "arguments": {"format": "celsius", "location": "Brooklyn, NY"},
            },
        }
    ]

    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_stream(
    flash_llama_grammar_tools, response_snapshot
):
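    """Streaming variant of the forced tool call: with seed=1 and temperature=0.0
    the number of streamed chunks is expected to be stable (48)."""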
    responses = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        temperature=0.0,
        tool_choice="get_current_weather",
        messages=[
            {
                "role": "system",
                "content": "Youre a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Paris, France?",
            },
        ],
        stream=True,
    )

    count = 0
    async for response in responses:
        count += 1

    assert count == 48
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_insufficient_information(
    flash_llama_grammar_tools, response_snapshot
):
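    """A question unrelated to the available tools: with tool_choice="auto" the
    model is expected to fall back to a notify_error tool call rather than
    free-form text."""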
    responses = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=24,
        tools=tools,
        tool_choice="auto",
        messages=[
            {
                "role": "system",
                "content": "STRICTLY ONLY RESPOND IF THE USER ASKS A WEATHER RELATED QUESTION",
            },
            {
                "role": "user",
                "content": "Tell me a story about 3 sea creatures",
            },
        ],
        stream=False,
    )

    assert responses.choices[0].message.content is None
    assert (
        responses.choices[0].message.tool_calls[0]["function"]["name"] == "notify_error"
    )
    assert responses == response_snapshot