import pytest
import requests
import json


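# launch a TGI server for Meta-Llama-3.1-8B-Instruct with grammar support
# enabled (disable_grammar_support=False), which the tool-calling tests rely on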
@pytest.fixture(scope="module")
def flash_llama_grammar_tools_handle(launcher):
    with launcher(
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        num_shard=2,
        disable_grammar_support=False,
    ) as handle:
        yield handle


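# wait up to 300 seconds for the launched server to report healthy, then
# expose its client to the tests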
@pytest.fixture(scope="module")
async def flash_llama_grammar_tools(flash_llama_grammar_tools_handle):
    await flash_llama_grammar_tools_handle.health(300)
    return flash_llama_grammar_tools_handle.client


# tools to be used in the following tests
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                },
                "required": ["location", "format"],
                "additionalProperties": False,
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_n_day_weather_forecast",
            "description": "Get an N-day weather forecast",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                    "num_days": {
                        "type": "integer",
                        "description": "The number of days to forecast",
                    },
                },
                "required": ["location", "format", "num_days"],
                "additionalProperties": False,
            },
        },
    },
]


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools(flash_llama_grammar_tools, response_snapshot):
    response = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        temperature=0.0,
        messages=[
            {
                "role": "system",
                "content": "Youre a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Brooklyn, New York?",
            },
        ],
    )
    assert response.choices[0].message.content is None
    assert response.choices[0].message.tool_calls == [
        {
            "id": "0",
            "type": "function",
            "function": {
                "description": None,
                "name": "get_current_weather",
                "arguments": {"format": "celsius", "location": "Brooklyn, New York"},
            },
        }
    ]
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_auto(
    flash_llama_grammar_tools, response_snapshot
):
    response = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        temperature=0.0,
        tool_choice="auto",
        messages=[
            {
                "role": "system",
                "content": "Youre a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Brooklyn, New York?",
            },
        ],
    )
    assert response.choices[0].message.content is None
    assert response.choices[0].message.tool_calls == [
        {
            "id": "0",
            "type": "function",
            "function": {
                "description": None,
                "name": "get_current_weather",
                "arguments": {"format": "celsius", "location": "Brooklyn, New York"},
            },
        }
    ]

    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_choice(
    flash_llama_grammar_tools, response_snapshot
):
    response = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        temperature=0.0,
        tool_choice="get_current_weather",
        messages=[
            {
                "role": "system",
                "content": "Youre a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Brooklyn, New York?",
            },
        ],
    )
    assert response.choices[0].message.content is None
    assert response.choices[0].message.tool_calls == [
        {
            "id": "0",
            "type": "function",
            "function": {
                "description": None,
                "name": "get_current_weather",
                "arguments": {"format": "celsius", "location": "Brooklyn, New York"},
            },
        }
    ]

    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_stream(
    flash_llama_grammar_tools, response_snapshot
):
    responses = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=1,
        tools=tools,
        temperature=0.0,
        tool_choice="get_current_weather",
        messages=[
            {
                "role": "system",
                "content": "Youre a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "What is the weather like in Paris, France?",
            },
        ],
        stream=True,
    )

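    # consume the stream, concatenating the tool-call argument fragments
    # carried by each delta so the full payload can be checked afterwards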
    count = 0
    tool_calls_generated = ""
    last_response = None
    async for response in responses:
        count += 1
        tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
        last_response = response
        assert response.choices[0].delta.content is None

    assert (
        tool_calls_generated
        == '{"function": {"_name": "get_current_weather", "location": "Paris, France", "format": "celsius"}}<|eot_id|>'
    )
    assert count == 28
    assert last_response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_insufficient_information(
    flash_llama_grammar_tools, response_snapshot
):
    responses = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=24,
        tools=tools,
        tool_choice="auto",
        messages=[
            {
                "role": "system",
                "content": "You're a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "Who are you?",
            },
        ],
        stream=False,
    )

    assert responses.choices[0].message.tool_calls is None
    assert responses.choices[0].message.content == "I am an AI assistant"

    assert responses == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_insufficient_information_stream(
    flash_llama_grammar_tools, response_snapshot
):
    responses = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=24,
        tools=tools,
        tool_choice="auto",
        messages=[
            {
                "role": "system",
                "content": "You're a helpful assistant! Answer the users question best you can.",
            },
            {
                "role": "user",
                "content": "Who are you?",
            },
        ],
        stream=True,
    )

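    # the weather tools cannot answer "Who are you?", so the model should
    # stream plain content deltas and emit no tool calls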
    count = 0
    content_generated = ""
    last_response = None
    async for response in responses:
        count += 1
        content_generated += response.choices[0].delta.content
        last_response = response
        assert response.choices[0].delta.tool_calls is None

    assert count == 5
    assert content_generated == "I am an AI assistant"
    assert last_response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream(
    flash_llama_grammar_tools, response_snapshot
):
    responses = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=24,
        tools=tools,
        tool_choice="auto",
        messages=[
            {
                "role": "system",
                "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
            },
            {
                "role": "user",
                "content": "Tell me a story about 3 sea creatures",
            },
        ],
        stream=True,
    )

    count = 0
    content_generated = ""
    last_response = None
    async for response in responses:
        count += 1
        content_generated += response.choices[0].delta.content
        last_response = response
        assert response.choices[0].delta.tool_calls is None

    assert count == 62
    assert (
        content_generated
        == "Once upon a time, in the ocean, there lived three sea creatures. There was a wise old octopus named Bob, a mischievous seagull named Sam, and a gentle sea turtle named Luna. They all lived together in a beautiful coral reef, surrounded by colorful fish and swaying sea fans"
    )
    assert last_response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream_required(
    flash_llama_grammar_tools, response_snapshot
):
    responses = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=24,
        tools=tools,
        tool_choice="required",
        messages=[
            {
                "role": "system",
                "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
            },
            {
                "role": "user",
                "content": "Tell me a story about 3 sea creatures",
            },
        ],
        stream=True,
    )

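    # "required" forces a tool call even though the prompt asks for a story,
    # so every delta should carry tool-call arguments and no content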
    count = 0
    tool_calls_generated = ""
    last_response = None
    async for response in responses:
        count += 1
        assert response.choices[0].delta.content is None
        tool_calls_generated += response.choices[0].delta.tool_calls.function.arguments
        last_response = response

    assert count == 29
    assert (
        tool_calls_generated
        == '{"function": {"_name": "get_current_weather", "location": "San Francisco, CA", "format": "celsius"}}<|eot_id|>'
    )
    assert last_response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream_none(
    flash_llama_grammar_tools, response_snapshot
):
    responses = await flash_llama_grammar_tools.chat(
        max_tokens=100,
        seed=24,
        tools=tools,
        tool_choice="none",
        messages=[
            {
                "role": "system",
                "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
            },
            {
                "role": "user",
                "content": "Tell me a story about 3 sea creatures",
            },
        ],
        stream=True,
    )

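    # "none" disables tool calling entirely, so only content deltas are expected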
    count = 0
    content_generated = ""
    last_response = None
    async for response in responses:
        count += 1
        content_generated += response.choices[0].delta.content
        last_response = response
        assert response.choices[0].delta.tool_calls is None

    assert count == 100
    assert (
        content_generated
        == "Once upon a time, in a vibrant ocean filled with coral reefs and schools of shimmering fish, lived three dear friends: Luna the sea turtle, Finley the friendly fish, and Crusty the wise crab.\n\nLuna was the oldest of the three. She had traveled the world, exploring hidden caves and shipwrecks, and collecting sparkling shells and shiny pebbles. Her shell was a beautiful mosaic of blues and greens, and her gentle eyes twinkled with the secrets of the deep"
    )
    assert last_response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_grammar_tools_sea_creatures_stream_function_object(
    flash_llama_grammar_tools, response_snapshot
):
    # using `requests` to send the request until the client library supports tool_choice as a function object
    responses = requests.post(
        f"{flash_llama_grammar_tools.base_url}/v1/chat/completions",
        headers=flash_llama_grammar_tools.headers,
        json={
            "model": "tgi",
            "messages": [
                {
                    "role": "system",
                    "content": "You're a helpful assistant! Answer the users question best you can. If the question is not answerable by the tools, just generate a response.",
                },
                {
                    "role": "user",
                    "content": "Tell me a story about 3 sea creatures",
                },
            ],
            "tools": tools,
            "tool_choice": {
                "type": "function",
                "function": {"name": "get_n_day_weather_forecast"},
            },
            "seed": 24,
            "max_tokens": 100,
            "stream": True,
        },
        stream=True,
    )
    # iterate over the response in chunks
    count = 0
    tool_calls_generated = ""
    last_response = None
    for chunk in responses.iter_content(chunk_size=1024):
        if chunk:
            count += 1
            # remove the "data: " prefix, trailing newline, and split the chunk into individual lines
            lines = chunk.decode("utf-8").replace("data: ", "").rstrip("\n").split("\n")
            for line in lines:
                if line == "[DONE]":
                    break
                response = json.loads(line)
                tool_calls_generated += response["choices"][0]["delta"]["tool_calls"][
                    "function"
                ]["arguments"]
                last_response = response

    assert count == 39
    assert (
        tool_calls_generated
        == '{"function": {"_name": "get_n_day_weather_forecast", "location": "San Francisco, CA", "format": "celsius", "num_days":3}}<|eot_id|>'
    )
    assert last_response == response_snapshot