# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
5
from collections.abc import Generator
Zuxin's avatar
Zuxin committed
6
7
8

import pytest

9
10
11
12
13
14
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaMessage,
    FunctionCall,
    ToolCall,
)
15
from vllm.entrypoints.openai.tool_parsers.xlam_tool_parser import xLAMToolParser
16
from vllm.tokenizers import TokenizerLike
17
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
18
from vllm.transformers_utils.tokenizer import get_tokenizer
Zuxin's avatar
Zuxin committed
19

20
21
# Mark every test in this module as runnable on CPU.
pytestmark = pytest.mark.cpu_test

# Use a common model that is likely to be available
MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r"


@pytest.fixture(scope="module")
def xlam_tokenizer():
    """Load the tokenizer for ``MODEL`` once and share it across this module."""
    tokenizer = get_tokenizer(tokenizer_name=MODEL)
    return tokenizer


@pytest.fixture
def xlam_tool_parser(xlam_tokenizer):
    """Build a new xLAMToolParser around the shared tokenizer for each test."""
    parser = xLAMToolParser(xlam_tokenizer)
    return parser


def assert_tool_calls(
    actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
):
    """Assert that parsed tool calls match the expected ones, pairwise.

    Ids are not compared, because the expected ``ToolCall`` objects are built
    without an explicit id. Each actual id is only checked to be a string
    longer than 16 characters. The ``type`` must be ``"function"``, and the
    ``function`` (name and serialized arguments) must equal the expectation.
    """
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for actual_tool_call, expected_tool_call in zip(
        actual_tool_calls, expected_tool_calls
    ):
        assert isinstance(actual_tool_call.id, str)
        assert len(actual_tool_call.id) > 16

        assert actual_tool_call.type == "function"
        assert actual_tool_call.function == expected_tool_call.function


def stream_delta_message_generator(
    xlam_tool_parser: xLAMToolParser,
    xlam_tokenizer: TokenizerLike,
    model_output: str,
    request: ChatCompletionRequest | None = None,
) -> Generator[DeltaMessage, None, None]:
    """Replay ``model_output`` through the streaming parser one token at a time.

    The text is encoded without special tokens, then detokenized
    incrementally. At each step the parser receives the previous, current and
    delta text together with the matching token-id lists. Steps where the
    parser returns a falsy result (e.g. ``None``) are not yielded.
    """
    all_token_ids = xlam_tokenizer.encode(model_output, add_special_tokens=False)

    previous_text = ""
    previous_tokens = None
    prefix_offset = 0
    read_offset = 0
    for i, delta_token in enumerate(all_token_ids):
        delta_token_ids = [delta_token]
        previous_token_ids = all_token_ids[:i]
        current_token_ids = all_token_ids[: i + 1]

        # Decode only the newly appended token; the detokenizer offsets and
        # token list are carried forward between iterations.
        (new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
            detokenize_incrementally(
                tokenizer=xlam_tokenizer,
                all_input_ids=current_token_ids,
                prev_tokens=previous_tokens,
                prefix_offset=prefix_offset,
                read_offset=read_offset,
                skip_special_tokens=False,
                spaces_between_special_tokens=True,
            )
        )

        current_text = previous_text + delta_text

        delta_message = xlam_tool_parser.extract_tool_calls_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
            request=request,
        )
        if delta_message:
            yield delta_message

        previous_text = current_text
        previous_tokens = (
            previous_tokens + new_tokens if previous_tokens else new_tokens
        )
        prefix_offset = new_prefix_offset
        read_offset = new_read_offset


def test_extract_tool_calls_no_tools(xlam_tool_parser):
    """Plain text with no tool calls is returned unchanged as content."""
    model_output = "This is a test"
    extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert not extracted_tool_calls.tools_called
    assert extracted_tool_calls.tool_calls == []
    assert extracted_tool_calls.content == model_output


@pytest.mark.parametrize(
    ids=[
        "parallel_tool_calls",
        "single_tool_with_think_tag",
        "single_tool_with_json_code_block",
        "single_tool_with_tool_calls_tag",
        "single_tool_with_tool_call_xml_tags",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Orlando",
                                "state": "FL",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                ),
            ],
            None,
        ),
        (
            """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                )
            ],
            "<think>I'll help you with that.</think>",
        ),
        (
            """I'll help you with that.\n```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                )
            ],
            "I'll help you with that.",
        ),
        (
            """I'll check the weather for you.[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                )
            ],
            "I'll check the weather for you.",
        ),
        (
            """I'll help you check the weather.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                )
            ],
            "I'll help you check the weather.",
        ),
    ],
)
def test_extract_tool_calls(
    xlam_tool_parser, model_output, expected_tool_calls, expected_content
):
    """Non-streaming extraction handles each supported tool-call wrapper.

    Cases cover:

    - a bare JSON list;
    - a leading ``<think>`` section;
    - a fenced ``json`` code block;
    - a ``[TOOL_CALLS]`` prefix;
    - ``<tool_call>`` tags.

    Any text preceding the tool calls is expected back as ``content``. It is
    ``None`` when there is no such text.
    """
    extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert extracted_tool_calls.tools_called

    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)

    assert extracted_tool_calls.content == expected_content


@pytest.mark.parametrize(
    ids=["list_structured_tool_call"],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """[{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Seattle",
                                "state": "WA",
                                "unit": "celsius",
                            }
                        ),
                    )
                )
            ],
            None,
        ),
    ],
)
def test_extract_tool_calls_list_structure(
    xlam_tool_parser, model_output, expected_tool_calls, expected_content
):
    """Extract tool calls when the model emits a bare list-structured call."""
    extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]
    assert extracted_tool_calls.tools_called

    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)

    assert extracted_tool_calls.content == expected_content


# Test for preprocess_model_output method
def test_preprocess_model_output(xlam_tool_parser):
    """preprocess_model_output splits leading content from tool-call text.

    It returns ``(content, potential_tool_calls)``. ``content`` is None when
    the output is only a tool-call list. ``potential_tool_calls`` is None when
    no tool call is present.
    """
    # Test with list structure
    model_output = (
        """[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]"""  # noqa: E501
    )
    content, potential_tool_calls = xlam_tool_parser.preprocess_model_output(
        model_output
    )
    assert content is None
    assert potential_tool_calls == model_output

    # Test with thinking tag
    model_output = """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]"""  # noqa: E501
    content, potential_tool_calls = xlam_tool_parser.preprocess_model_output(
        model_output
    )
    assert content == "<think>I'll help you with that.</think>"
    assert (
        potential_tool_calls
        == '[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]'
    )

    # Test with JSON code block
    model_output = """I'll help you with that.
```json
[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]
```"""
    content, potential_tool_calls = xlam_tool_parser.preprocess_model_output(
        model_output
    )
    assert content == "I'll help you with that."
    assert "get_current_weather" in potential_tool_calls

    # Test with no tool calls
    model_output = """I'll help you with that."""
    content, potential_tool_calls = xlam_tool_parser.preprocess_model_output(
        model_output
    )
    assert content == model_output
    assert potential_tool_calls is None


# Simulate streaming to test extract_tool_calls_streaming
def test_streaming_with_list_structure(xlam_tool_parser):
    """Streaming a complete list-structured call initializes and names the tool.

    The test proceeds in three steps:

    1. Reset the parser's streaming state by hand.
    2. Feed the full text in one call.
    3. Rewind ``current_tools_sent`` and make a second call with an empty
       delta, which is expected to emit the function name.
    """
    # Reset streaming state
    xlam_tool_parser.prev_tool_calls = []
    xlam_tool_parser.current_tools_sent = []
    xlam_tool_parser.streamed_args = []
    xlam_tool_parser.current_tool_id = -1

    # Simulate receiving a message with list structure
    current_text = (
        """[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]"""  # noqa: E501
    )

    # First call to set up the tool
    xlam_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=current_text,
        delta_text="]",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # Make sure the tool is set up correctly
    assert xlam_tool_parser.current_tool_id >= 0, "Tool index should be initialized"

    # Manually set up the state for sending the tool name
    xlam_tool_parser.current_tools_sent = [False]

    # Call to send the function name
    result = xlam_tool_parser.extract_tool_calls_streaming(
        previous_text=current_text,
        current_text=current_text,
        delta_text="",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # Check that we get a result with the proper tool call.
    # NOTE(review): these assertions only run when the parser returns a
    # delta; a None result lets this test pass without checking anything.
    if result is not None:
        assert hasattr(result, "tool_calls")
        assert len(result.tool_calls) == 1
        assert result.tool_calls[0].function.name == "get_current_weather"


@pytest.mark.parametrize(
    ids=[
        "parallel_tool_calls",
        "single_tool_with_think_tag",
        "single_tool_with_json_code_block",
        "single_tool_with_tool_calls_tag",
        "single_tool_with_tool_call_xml_tags",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                ),
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Orlando",
                                "state": "FL",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                ),
            ],
            "",
        ),
        (
            """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                )
            ],
            "<think>I'll help you with that.</think>",
        ),
        (
            """```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                )
            ],
            "",
        ),
        (
            """[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                )
            ],
            "",
        ),
        (
            """I can help with that.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""",  # noqa: E501
            [
                ToolCall(
                    function=FunctionCall(
                        name="get_current_weather",
                        arguments=json.dumps(
                            {
                                "city": "Dallas",
                                "state": "TX",
                                "unit": "fahrenheit",
                            }
                        ),
                    )
                )
            ],
            "I can help with that.",
        ),
    ],
)
def test_extract_tool_calls_streaming_incremental(
    xlam_tool_parser,
    xlam_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Verify the xLAM parser's streaming output chunk by chunk.

    The model output is replayed token by token. The resulting deltas must
    include:

    - a header for the first tool call: id, function name, type
      ``"function"``, and arguments that are missing or empty;
    - several argument fragments for that call, which together form the
      expected JSON arguments.
    """
    # NOTE(review): ``expected_content`` comes from the parametrization but
    # is not asserted on in this test.
    request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])

    chunks = list(
        stream_delta_message_generator(
            xlam_tool_parser, xlam_tokenizer, model_output, request
        )
    )

    # Should have multiple chunks
    assert len(chunks) >= 3

    # The first chunk whose leading tool call carries an id is the header for
    # the first tool call (id, name, type).
    header_found = False
    expected_first_tool = expected_tool_calls[0]
    for chunk in chunks:
        if chunk.tool_calls and chunk.tool_calls[0].id:
            header_found = True
            assert (
                chunk.tool_calls[0].function.name == expected_first_tool.function.name
            )
            assert chunk.tool_calls[0].type == "function"
            # Arguments may be absent (None) in the header chunk; if present,
            # they must still be empty.
            if chunk.tool_calls[0].function.arguments is not None:
                assert chunk.tool_calls[0].function.arguments == ""
            break
    assert header_found

    # Collect the non-empty argument fragments streamed for the first tool
    # call (index 0) only; a truthy string is already non-empty.
    arg_chunks = [
        chunk.tool_calls[0].function.arguments
        for chunk in chunks
        if chunk.tool_calls
        and chunk.tool_calls[0].function.arguments
        and chunk.tool_calls[0].index == 0
    ]

    # Arguments should be streamed incrementally
    assert len(arg_chunks) > 1

    # Concatenated arguments should form valid JSON for the first tool call
    full_args = "".join(arg_chunks)
    parsed_args = json.loads(full_args)
    expected_args = json.loads(expected_first_tool.function.arguments)
    assert parsed_args == expected_args