# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
from collections.abc import Generator

import pytest

from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaMessage,
    FunctionCall,
    ToolCall,
)
from vllm.entrypoints.openai.tool_parsers import xLAMToolParser
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer

pytestmark = pytest.mark.cpu_test

# Hugging Face model id whose tokenizer is used to build the parser and to
# tokenize the simulated model output in the streaming tests.
MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r"


@pytest.fixture(scope="module")
def xlam_tokenizer():
    """Load the ``MODEL`` tokenizer once and share it across this module."""
    tokenizer = get_tokenizer(tokenizer_name=MODEL)
    return tokenizer


@pytest.fixture
def xlam_tool_parser(xlam_tokenizer):
    """Build a new ``xLAMToolParser`` for each test (function-scoped fixture).

    Because a fresh parser is created per test, any state the parser keeps
    between streaming calls is not shared across tests.
    """
    parser = xLAMToolParser(xlam_tokenizer)
    return parser


def assert_tool_calls(
    actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
):
    """Assert that parsed tool calls match the expected ones, pairwise.

    Args:
        actual_tool_calls: Tool calls produced by the parser.
        expected_tool_calls: Expected tool calls; only their ``function``
            (name + JSON-encoded arguments) is compared.
    """
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for actual_tool_call, expected_tool_call in zip(
        actual_tool_calls, expected_tool_calls
    ):
        # The expected ToolCall objects are built without an explicit id, so
        # only the shape of the parser-assigned id is checked here.
        assert isinstance(actual_tool_call.id, str)
        assert len(actual_tool_call.id) > 16

        assert actual_tool_call.type == "function"
        assert actual_tool_call.function == expected_tool_call.function


def stream_delta_message_generator(
    xlam_tool_parser: xLAMToolParser,
    xlam_tokenizer: AnyTokenizer,
    model_output: str,
    request: ChatCompletionRequest | None = None,
) -> Generator[DeltaMessage, None, None]:
    """Replay ``model_output`` through the parser one token at a time.

    The text is tokenized without special tokens, then decoded step by step
    with ``detokenize_incrementally``; at each step the previous/current/delta
    text and token ids are passed to ``extract_tool_calls_streaming``.

    Args:
        xlam_tool_parser: Parser under test.
        xlam_tokenizer: Tokenizer used both to encode and to decode.
        model_output: Full simulated model output to stream.
        request: Optional request forwarded to the parser.

    Yields:
        Each truthy ``DeltaMessage`` returned by the parser (``None`` results
        are skipped).
    """
    all_token_ids = xlam_tokenizer.encode(model_output, add_special_tokens=False)

    previous_text = ""
    previous_tokens = None
    prefix_offset = 0
    read_offset = 0
    for i, delta_token in enumerate(all_token_ids):
        delta_token_ids = [delta_token]
        previous_token_ids = all_token_ids[:i]
        current_token_ids = all_token_ids[: i + 1]

        (new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
            detokenize_incrementally(
                tokenizer=xlam_tokenizer,
                all_input_ids=current_token_ids,
                prev_tokens=previous_tokens,
                prefix_offset=prefix_offset,
                read_offset=read_offset,
                skip_special_tokens=False,
                spaces_between_special_tokens=True,
            )
        )

        current_text = previous_text + delta_text

        delta_message = xlam_tool_parser.extract_tool_calls_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
            request=request,
        )
        if delta_message:
            yield delta_message

        # Carry the decoder state forward into the next incremental
        # detokenization call.
        previous_text = current_text
        previous_tokens = (
            previous_tokens + new_tokens if previous_tokens else new_tokens
        )
        prefix_offset = new_prefix_offset
        read_offset = new_read_offset


def test_extract_tool_calls_no_tools(xlam_tool_parser):
    """Plain text with no tool-call JSON is returned unchanged as content."""
    model_output = "This is a test"
    extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]

    assert not extracted_tool_calls.tools_called
    assert extracted_tool_calls.tool_calls == []
    assert extracted_tool_calls.content == model_output


@pytest.mark.parametrize(
    ids=[
        "parallel_tool_calls",
        "single_tool_with_think_tag",
        "single_tool_with_json_code_block",
        "single_tool_with_tool_calls_tag",
        "single_tool_with_tool_call_xml_tags",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Orlando",
                                "state": "FL",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                ),
            ],
            None,
        ),
        (
            """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                )
            ],
            "<think>I'll help you with that.</think>",
        ),
        (
            """I'll help you with that.\n```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                )
            ],
            "I'll help you with that.",
        ),
        (
            """I'll check the weather for you.[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                )
            ],
            "I'll check the weather for you.",
        ),
        (
            """I'll help you check the weather.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                )
            ],
            "I'll help you check the weather.",
        ),
    ],
)
def test_extract_tool_calls(
    xlam_tool_parser, model_output, expected_tool_calls, expected_content
):
    """Non-streaming extraction across the supported tool-call wrappers.

    Covers a bare JSON list, a ``<think>`` prefix, a fenced json code block, a
    ``[TOOL_CALLS]`` marker and ``<tool_call>`` tags. Text outside the tool
    call must be returned as ``content`` (``None`` for the bare-list case).
    """
    extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]

    assert extracted_tool_calls.tools_called

    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)

    assert extracted_tool_calls.content == expected_content


@pytest.mark.parametrize(
    ids=["list_structured_tool_call"],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """[{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Seattle",
                                "state": "WA",
                                "unit": "celsius",
                            }
                        ),
                    )
                )
            ],
            None,
        ),
    ],
)
def test_extract_tool_calls_list_structure(
    xlam_tool_parser, model_output, expected_tool_calls, expected_content
):
    """Extract tool calls when the model emits a bare JSON list of calls."""
    extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]

    assert extracted_tool_calls.tools_called

    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)

    assert extracted_tool_calls.content == expected_content


def test_preprocess_model_output(xlam_tool_parser):
    """Check how ``preprocess_model_output`` splits text from tool-call JSON.

    The method returns a ``(content, potential_tool_calls)`` pair; the cases
    below cover a bare JSON list, a ``<think>`` prefix, a fenced json code
    block, and plain text with no tool calls.
    """
    # Test with list structure
    model_output = (
        """[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]"""  # noqa: E501
    )
    content, potential_tool_calls = xlam_tool_parser.preprocess_model_output(
        model_output
    )
    assert content is None
    assert potential_tool_calls == model_output

    # Test with thinking tag
    model_output = """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]"""  # noqa: E501
    content, potential_tool_calls = xlam_tool_parser.preprocess_model_output(
        model_output
    )
    assert content == "<think>I'll help you with that.</think>"
    assert (
        potential_tool_calls
        == '[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]'
    )

    # Test with JSON code block
    model_output = """I'll help you with that.
```json
[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]
```"""
    content, potential_tool_calls = xlam_tool_parser.preprocess_model_output(
        model_output
    )
    assert content == "I'll help you with that."
    assert "get_current_weather" in potential_tool_calls

    # Test with no tool calls
    model_output = """I'll help you with that."""
    content, potential_tool_calls = xlam_tool_parser.preprocess_model_output(
        model_output
    )
    assert content == model_output
    assert potential_tool_calls is None


def test_streaming_with_list_structure(xlam_tool_parser):
    """Simulate streaming to exercise ``extract_tool_calls_streaming``.

    A complete JSON list is fed in a single step to initialise the tool
    state; a second, empty-delta step is then checked for the function name.
    """
    # Reset streaming state
    xlam_tool_parser.prev_tool_calls = []
    xlam_tool_parser.current_tools_sent = []
    xlam_tool_parser.streamed_args = []
    xlam_tool_parser.current_tool_id = -1

    # Simulate receiving a message with list structure
    current_text = (
        """[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]"""  # noqa: E501
    )

    # First call to set up the tool
    xlam_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=current_text,
        delta_text="]",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # Make sure the tool is set up correctly
    assert xlam_tool_parser.current_tool_id >= 0, "Tool index should be initialized"

    # Manually set up the state for sending the tool name
    xlam_tool_parser.current_tools_sent = [False]

    # Call to send the function name
    result = xlam_tool_parser.extract_tool_calls_streaming(
        previous_text=current_text,
        current_text=current_text,
        delta_text="",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # Check that we get a result with the proper tool call.
    # NOTE(review): these checks are skipped when the parser returns None, so
    # the test still passes if no function name is emitted — confirm whether
    # a non-None result should be required here.
    if result is not None:
        assert hasattr(result, "tool_calls")
        assert len(result.tool_calls) == 1
        assert result.tool_calls[0].function.name == "get_current_weather"


@pytest.mark.parametrize(
    ids=[
        "parallel_tool_calls",
        "single_tool_with_think_tag",
        "single_tool_with_json_code_block",
        "single_tool_with_tool_calls_tag",
        "single_tool_with_tool_call_xml_tags",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Orlando",
                                "state": "FL",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                ),
            ],
            "",
        ),
        (
            """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                )
            ],
            "<think>I'll help you with that.</think>",
        ),
        (
            """```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                )
            ],
            "",
        ),
        (
            """I can help with that.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                )
            ],
            "I can help with that.",
        ),
    ],
)
def test_extract_tool_calls_streaming_incremental(
    xlam_tool_parser,
    xlam_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Verify the xLAM parser's streaming output chunk by chunk.

    Checks that the first tool call's header (id, name, type) is emitted and
    that its arguments arrive in more than one fragment which, concatenated,
    parse to the expected JSON.

    NOTE(review): ``expected_content`` is parametrized but not asserted in
    this test — confirm whether streamed content should also be checked.
    """
    request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])

    chunks = list(
        stream_delta_message_generator(
            xlam_tool_parser, xlam_tokenizer, model_output, request
        )
    )

    # Should have multiple chunks
    assert len(chunks) >= 3

    # The first chunk carrying a tool-call id is the header for the first tool.
    header_found = False
    expected_first_tool = expected_tool_calls[0]
    for chunk in chunks:
        if chunk.tool_calls and chunk.tool_calls[0].id:
            header_found = True
            assert (
                chunk.tool_calls[0].function.name == expected_first_tool.function.name
            )
            assert chunk.tool_calls[0].type == "function"
            # Arguments may be absent (None) in the header; if present they
            # must still be empty at this point.
            if chunk.tool_calls[0].function.arguments is not None:
                assert chunk.tool_calls[0].function.arguments == ""
            break
    assert header_found

    # Collect the non-empty argument fragments streamed for the first tool
    # call (index 0) only.
    arg_chunks = [
        chunk.tool_calls[0].function.arguments
        for chunk in chunks
        if chunk.tool_calls
        and chunk.tool_calls[0].function.arguments
        and chunk.tool_calls[0].index == 0
    ]

    # Arguments should be streamed incrementally
    assert len(arg_chunks) > 1

    # Concatenated arguments should form valid JSON for the first tool call
    full_args = "".join(arg_chunks)
    parsed_args = json.loads(full_args)
    expected_args = json.loads(expected_first_tool.function.arguments)
    assert parsed_args == expected_args