test_xlam_tool_parser.py 17.2 KB
Newer Older
Zuxin's avatar
Zuxin committed
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
Zuxin's avatar
Zuxin committed
3
4

import json
5
6
from collections.abc import Generator
from typing import Optional
Zuxin's avatar
Zuxin committed
7
8
9

import pytest

10
11
12
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaMessage, FunctionCall,
                                              ToolCall)
Zuxin's avatar
Zuxin committed
13
from vllm.entrypoints.openai.tool_parsers import xLAMToolParser
14
from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally
15
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
Zuxin's avatar
Zuxin committed
16

17
18
# Module-level marks: pytest applies every mark listed here to all tests
# collected from this file.
pytestmark = [pytest.mark.cpu_test]

Zuxin's avatar
Zuxin committed
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# Checkpoint whose tokenizer backs every parser instance in this module.
MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r"


@pytest.fixture(scope="module")
def xlam_tokenizer():
    """Load the tokenizer for ``MODEL`` once and share it module-wide."""
    tokenizer = get_tokenizer(tokenizer_name=MODEL)
    return tokenizer


@pytest.fixture
def xlam_tool_parser(xlam_tokenizer):
    """Build a new ``xLAMToolParser`` per test (default function scope)."""
    parser = xLAMToolParser(xlam_tokenizer)
    return parser


def assert_tool_calls(actual_tool_calls: list[ToolCall],
                      expected_tool_calls: list[ToolCall]):
    """Assert that parsed tool calls line up one-to-one with expectations.

    The two lists must have equal length. For each pair, the actual call
    must have a string ``id`` longer than 16 characters and ``type ==
    "function"``. Its ``function`` (name + arguments) must equal the
    expected one. Ids themselves are never compared against the expected
    calls.
    """
    assert len(actual_tool_calls) == len(expected_tool_calls)

    for got, want in zip(actual_tool_calls, expected_tool_calls):
        call_id = got.id
        assert isinstance(call_id, str)
        assert len(call_id) > 16

        assert got.type == "function"
        assert got.function == want.function


46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
def stream_delta_message_generator(
    xlam_tool_parser: xLAMToolParser,
    xlam_tokenizer: AnyTokenizer,
    model_output: str,
    request: Optional[ChatCompletionRequest] = None,
) -> Generator[DeltaMessage, None, None]:
    """Replay ``model_output`` through the streaming parser token by token.

    The text is encoded without special tokens. Each token is fed
    incrementally via ``detokenize_incrementally`` (special tokens kept,
    spaces between special tokens enabled). After each step, the parser's
    ``extract_tool_calls_streaming`` is called with the previous/current/
    delta text and token ids. Every truthy delta message it returns is
    yielded.
    """
    token_ids = xlam_tokenizer.encode(model_output,
                                      add_special_tokens=False)

    text_so_far = ""
    tokens_so_far = None
    prefix_offset = 0
    read_offset = 0
    for end in range(1, len(token_ids) + 1):
        seen_ids = token_ids[:end]

        new_tokens, delta_text, prefix_offset, read_offset = (
            detokenize_incrementally(
                tokenizer=xlam_tokenizer,
                all_input_ids=seen_ids,
                prev_tokens=tokens_so_far,
                prefix_offset=prefix_offset,
                read_offset=read_offset,
                skip_special_tokens=False,
                spaces_between_special_tokens=True,
            ))

        updated_text = text_so_far + delta_text
        message = xlam_tool_parser.extract_tool_calls_streaming(
            text_so_far,
            updated_text,
            delta_text,
            token_ids[:end - 1],
            seen_ids,
            [token_ids[end - 1]],
            request=request,
        )
        if message:
            yield message

        text_so_far = updated_text
        # The first step has no token history yet, so start from new_tokens.
        if not tokens_so_far:
            tokens_so_far = new_tokens
        else:
            tokens_so_far = tokens_so_far + new_tokens


Zuxin's avatar
Zuxin committed
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
def test_extract_tool_calls_no_tools(xlam_tool_parser):
    """Plain prose with no tool-call payload comes back verbatim as content."""
    plain_text = "This is a test"
    result = xlam_tool_parser.extract_tool_calls(
        plain_text, request=None)  # type: ignore[arg-type]

    assert not result.tools_called
    assert result.tool_calls == []
    assert result.content == plain_text


@pytest.mark.parametrize(
    ids=[
        "parallel_tool_calls",
        "single_tool_with_think_tag",
        "single_tool_with_json_code_block",
        "single_tool_with_tool_calls_tag",
        "single_tool_with_tool_call_xml_tags",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                )),
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Orlando",
                        "state": "FL",
                        "unit": "fahrenheit",
                    }),
                )),
            ],
            None,
        ),
        (
            """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                ))
            ],
            "<think>I'll help you with that.</think>",
        ),
        (
            """I'll help you with that.\n```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                ))
            ],
            "I'll help you with that.",
        ),
        (
            """I'll check the weather for you.[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                ))
            ],
            "I'll check the weather for you.",
        ),
        (
            """I'll help you check the weather.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                ))
            ],
            "I'll help you check the weather.",
        ),
    ],
)
def test_extract_tool_calls(xlam_tool_parser, model_output,
                            expected_tool_calls, expected_content):
    """Non-streaming extraction across every supported tool-call wrapper.

    The cases cover:
    - a bare JSON list;
    - a ``<think>`` prefix;
    - a fenced ```json block;
    - the ``[TOOL_CALLS]`` marker;
    - ``<tool_call>`` tags.

    Text outside the tool-call payload must come back as ``content``
    (``None`` when the output is only the JSON list).
    """
    extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]
    assert extracted_tool_calls.tools_called

    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)

    assert extracted_tool_calls.content == expected_content


@pytest.mark.parametrize(
    "model_output, expected_tool_calls, expected_content",
    [
        pytest.param(
            """[{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}]""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps(
                        dict(city="Seattle", state="WA", unit="celsius")),
                ))
            ],
            None,
            id="list_structured_tool_call",
        ),
    ],
)
def test_extract_tool_calls_list_structure(xlam_tool_parser, model_output,
                                           expected_tool_calls,
                                           expected_content):
    """A bare JSON list of calls is parsed fully, with no leftover content."""
    result = xlam_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]

    assert result.tools_called
    assert_tool_calls(result.tool_calls, expected_tool_calls)
    assert result.content == expected_content


def test_preprocess_model_output(xlam_tool_parser):
    """Check how ``preprocess_model_output`` splits prose from tool JSON.

    The method returns a ``(content, potential_tool_calls)`` pair. Four
    inputs are checked in turn: a bare JSON list, a ``<think>``-prefixed
    list, a fenced ```json block, and plain prose.
    """
    preprocess = xlam_tool_parser.preprocess_model_output
    seattle_call = '[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]'  # noqa: E501

    # A bare list is entirely tool calls: no content left over.
    text, calls = preprocess(seattle_call)
    assert text is None
    assert calls == seattle_call

    # The <think> section is kept as content; the list follows it.
    think_prefix = "<think>I'll help you with that.</think>"
    text, calls = preprocess(think_prefix + seattle_call)
    assert text == think_prefix
    assert calls == seattle_call

    # A ```json fence is lifted out of the surrounding prose.
    fenced = "I'll help you with that.\n```json\n" + seattle_call + "\n```"
    text, calls = preprocess(fenced)
    assert text == "I'll help you with that."
    assert "get_current_weather" in calls

    # Plain prose: everything is content and there are no tool calls.
    plain = "I'll help you with that."
    text, calls = preprocess(plain)
    assert text == plain
    assert calls is None


def test_streaming_with_list_structure(xlam_tool_parser):
    """Drive ``extract_tool_calls_streaming`` by hand on a complete list.

    Steps:
    1. Reset the parser's streaming attributes.
    2. Feed the full JSON list once, so the parser initializes its tool
       index.
    3. Prime ``current_tools_sent`` with ``[False]``.
    4. Call again and require a delta carrying the function name.
    """
    # Reset streaming state so this test does not depend on prior calls.
    xlam_tool_parser.prev_tool_calls = []
    xlam_tool_parser.current_tools_sent = []
    xlam_tool_parser.streamed_args = []
    xlam_tool_parser.current_tool_id = -1

    # A complete, list-structured tool call.
    current_text = """[{"name": "get_current_weather", "arguments": {"city": "Seattle"}}]"""  # noqa: E501

    # First call: lets the parser register the tool.
    xlam_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=current_text,
        delta_text="]",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    assert (xlam_tool_parser.current_tool_id
            >= 0), "Tool index should be initialized"

    # Mark the (single) tool's name as not yet sent.
    xlam_tool_parser.current_tools_sent = [False]

    # Second call: should emit the function name for tool 0.
    result = xlam_tool_parser.extract_tool_calls_streaming(
        previous_text=current_text,
        current_text=current_text,
        delta_text="",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # These checks used to sit under `if result is not None:`, which let the
    # test pass without verifying anything when no delta was produced.
    assert result is not None, "Expected a delta carrying the tool name"
    assert len(result.tool_calls) == 1
    assert result.tool_calls[0].function.name == "get_current_weather"
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463


@pytest.mark.parametrize(
    ids=[
        "parallel_tool_calls",
        "single_tool_with_think_tag",
        "single_tool_with_json_code_block",
        "single_tool_with_tool_calls_tag",
        "single_tool_with_tool_call_xml_tags",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        (
            """[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                )),
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Orlando",
                        "state": "FL",
                        "unit": "fahrenheit",
                    }),
                )),
            ],
            "",
        ),
        (
            """<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                ))
            ],
            "<think>I'll help you with that.</think>",
        ),
        (
            """```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                ))
            ],
            "",
        ),
        (
            """[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                ))
            ],
            "",
        ),
        (
            """I can help with that.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""",  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "Dallas",
                        "state": "TX",
                        "unit": "fahrenheit",
                    }),
                ))
            ],
            "I can help with that.",
        ),
    ],
)
def test_extract_tool_calls_streaming_incremental(
    xlam_tool_parser,
    xlam_tokenizer,
    model_output,
    expected_tool_calls,
    expected_content,
):
    """Verify the XLAM Parser streaming behavior by verifying each chunk is as expected.

    Checks that, for the first tool call (index 0):
    - a header delta (id, name, type, empty/absent arguments) is streamed;
    - the arguments arrive in more than one fragment;
    - the fragments concatenate to the expected JSON.

    NOTE(review): ``expected_content`` is accepted but not asserted here;
    streamed content is not checked by this test.
    """  # noqa: E501
    request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])

    chunks = list(
        stream_delta_message_generator(xlam_tool_parser, xlam_tokenizer,
                                       model_output, request))

    # Should have multiple chunks
    assert len(chunks) >= 3

    expected_first_tool = expected_tool_calls[0]

    # The first delta whose leading tool call carries an id is the header.
    header = next(
        (chunk.tool_calls[0]
         for chunk in chunks if chunk.tool_calls and chunk.tool_calls[0].id),
        None,
    )
    assert header is not None, "No tool-call header chunk was streamed"
    # Fail with a clear message instead of an AttributeError below.
    assert header.function is not None, "Header chunk has no function"
    assert header.function.name == expected_first_tool.function.name
    assert header.type == "function"
    # Arguments may be absent on the header; if present they must be empty.
    if header.function.arguments is not None:
        assert header.function.arguments == ""

    # Collect non-empty argument fragments streamed for tool index 0.
    arg_chunks = []
    for chunk in chunks:
        if not chunk.tool_calls:
            continue
        call = chunk.tool_calls[0]
        # DeltaToolCall.function may be None; such deltas carry no arguments.
        if call.index != 0 or call.function is None:
            continue
        if call.function.arguments:  # skips both None and ""
            arg_chunks.append(call.function.arguments)

    # Arguments should be streamed incrementally
    assert len(arg_chunks) > 1

    # Concatenated arguments should form valid JSON for the first tool call
    full_args = "".join(arg_chunks)
    parsed_args = json.loads(full_args)
    expected_args = json.loads(expected_first_tool.function.arguments)
    assert parsed_args == expected_args