qwen3coder_tool_parser.py 27.7 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import ast
4
5
6
import json
import uuid
from collections.abc import Sequence
7
from typing import Any
8
9
10

import regex as re

11
from vllm.entrypoints.openai.chat_completion.protocol import (
12
    ChatCompletionRequest,
13
14
)
from vllm.entrypoints.openai.engine.protocol import (
15
16
17
18
19
20
21
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
22
from vllm.logger import init_logger
23
from vllm.tokenizers import TokenizerLike
24
from vllm.tool_parsers.abstract_tool_parser import (
25
    Tool,
26
27
    ToolParser,
)
28
from vllm.tool_parsers.utils import find_tool_properties
29
30
31
32
33

# Module-level logger, created through vLLM's ``init_logger`` with this
# module's ``__name__``; every log call in this file goes through it.
logger = init_logger(__name__)


class Qwen3CoderToolParser(ToolParser):
34
35
    def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
        """Initialise sentinel tokens, regex patterns and streaming state.

        Args:
            tokenizer: Tokenizer forwarded to the ``ToolParser`` base class.
            tools: Optional tool definitions forwarded to the base class.

        Raises:
            ValueError: If ``self.model_tokenizer`` is falsy after the base
                class has been initialised.
            RuntimeError: If ``<tool_call>`` or ``</tool_call>`` cannot be
                found in ``self.vocab``.
        """
        super().__init__(tokenizer, tools)

        # Tool-call bookkeeping shared with the rest of the parser.
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        # Tool-call IDs are strings here, which overrides the base class type.
        self.current_tool_id: str | None = None  # type: ignore
        self.streamed_args_for_tool: list[str] = []

        # Markup sentinels looked for in the model output.
        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"
        self.tool_call_prefix: str = "<function="
        self.function_end_token: str = "</function>"
        self.parameter_prefix: str = "<parameter="
        self.parameter_end_token: str = "</parameter>"
        self.is_tool_call_started: bool = False
        self.failed_count: int = 0

        # Start from a clean per-message streaming state.
        self._reset_streaming_state()

        # Every pattern uses DOTALL so that "." also spans newlines.
        dotall = re.DOTALL
        self.tool_call_complete_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>", dotall
        )
        self.tool_call_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", dotall
        )
        self.tool_call_function_regex = re.compile(
            r"<function=(.*?)</function>|<function=(.*)$", dotall
        )
        self.tool_call_parameter_regex = re.compile(
            r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)",
            dotall,
        )

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        # Resolve the sentinel token IDs; both must exist in the vocabulary.
        start_token_id = self.vocab.get(self.tool_call_start_token)
        end_token_id = self.vocab.get(self.tool_call_end_token)
        self.tool_call_start_token_id = start_token_id
        self.tool_call_end_token_id = end_token_id

        if start_token_id is None or end_token_id is None:
            raise RuntimeError(
                "Qwen3 XML Tool parser could not locate tool call start/end "
                "tokens in the tokenizer!"
            )

        logger.debug(
            "vLLM Successfully import tool parser %s !", self.__class__.__name__
        )
89
90
91
92
93
94
95
96
97
98

    def _generate_tool_call_id(self) -> str:
        """Generate a unique tool call ID."""
        return f"call_{uuid.uuid4().hex[:24]}"

    def _reset_streaming_state(self):
        """Reset all streaming state."""
        self.current_tool_index = 0
        self.is_tool_call_started = False
        self.header_sent = False
99
        self.current_tool_id = None
100
101
102
103
104
105
106
107
108
        self.current_function_name = None
        self.current_param_name = None
        self.current_param_value = ""
        self.param_count = 0
        self.in_param = False
        self.in_function = False
        self.accumulated_text = ""
        self.json_started = False
        self.json_closed = False
109
110
111
112
        # Store accumulated parameters for type conversion
        self.accumulated_params = {}
        self.streaming_request = None

113
114
115
    def _convert_param_value(
        self, param_value: str, param_name: str, param_config: dict, func_name: str
    ) -> Any:
116
        """Convert parameter value based on its type in the schema."""
117
        # Handle null value for any type
118
119
        if param_value.lower() == "null":
            return None
120

121
122
        if param_name not in param_config:
            if param_config != {}:
123
                logger.debug(
124
125
                    "Parsed parameter '%s' is not defined in the tool "
                    "parameters for tool '%s', directly returning the "
126
127
128
129
                    "string value.",
                    param_name,
                    func_name,
                )
130
131
            return param_value

132
133
134
135
136
137
138
139
140
141
142
143
144
        if (
            isinstance(param_config[param_name], dict)
            and "type" in param_config[param_name]
        ):
            param_type = str(param_config[param_name]["type"]).strip().lower()
        elif (
            isinstance(param_config[param_name], dict)
            and "anyOf" in param_config[param_name]
        ):
            # anyOf has no top-level "type"; treat as object to trigger json.loads.
            param_type = "object"
        else:
            param_type = "string"
145
146
        if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
            return param_value
147
148
149
150
151
152
153
        elif (
            param_type.startswith("int")
            or param_type.startswith("uint")
            or param_type.startswith("long")
            or param_type.startswith("short")
            or param_type.startswith("unsigned")
        ):
154
155
156
            try:
                return int(param_value)
            except (ValueError, TypeError):
157
                logger.debug(
158
159
                    "Parsed value '%s' of parameter '%s' is not an "
                    "integer in tool '%s', degenerating to string.",
160
161
162
163
                    param_value,
                    param_name,
                    func_name,
                )
164
                return param_value
165
166
167
        elif param_type.startswith("num") or param_type.startswith("float"):
            try:
                float_param_value = float(param_value)
168
169
170
171
172
                return (
                    float_param_value
                    if float_param_value - int(float_param_value) != 0
                    else int(float_param_value)
                )
173
            except (ValueError, TypeError):
174
                logger.debug(
175
                    "Parsed value '%s' of parameter '%s' is not a float "
176
177
178
179
180
                    "in tool '%s', degenerating to string.",
                    param_value,
                    param_name,
                    func_name,
                )
181
                return param_value
182
183
184
        elif param_type in ["boolean", "bool", "binary"]:
            param_value = param_value.lower()
            if param_value not in ["true", "false"]:
185
                logger.debug(
186
187
                    "Parsed value '%s' of parameter '%s' is not a boolean "
                    "(`true` or `false`) in tool '%s', degenerating to "
188
189
190
191
192
                    "false.",
                    param_value,
                    param_name,
                    func_name,
                )
193
194
            return param_value == "true"
        else:
195
196
197
198
199
            if (
                param_type in ["object", "array", "arr"]
                or param_type.startswith("dict")
                or param_type.startswith("list")
            ):
200
                try:
201
202
203
                    param_value = json.loads(param_value)
                    return param_value
                except (json.JSONDecodeError, TypeError, ValueError):
204
                    logger.debug(
205
206
                        "Parsed value '%s' of parameter '%s' cannot be "
                        "parsed with json.loads in tool '%s', will try "
207
208
209
210
211
                        "other methods to parse it.",
                        param_value,
                        param_name,
                        func_name,
                    )
212
213
214
            try:
                param_value = ast.literal_eval(param_value)  # safer
            except (ValueError, SyntaxError, TypeError):
215
                logger.debug(
216
217
                    "Parsed value '%s' of parameter '%s' cannot be "
                    "converted via Python `ast.literal_eval()` in tool "
218
219
220
221
222
                    "'%s', degenerating to string.",
                    param_value,
                    param_name,
                    func_name,
                )
223
224
            return param_value

225
    def _parse_xml_function_call(self, function_call_str: str) -> ToolCall | None:
        """Build a ``ToolCall`` from the text that follows ``<function=``.

        Everything before the first ``>`` is taken as the function name; the
        rest is scanned with ``self.tool_call_parameter_regex`` and each
        ``name>value`` match is converted via ``_convert_param_value``.

        Args:
            function_call_str: Function call text without the leading
                ``<function=`` prefix.

        Returns:
            The parsed ``ToolCall`` with JSON-encoded arguments, or ``None``
            when the text contains no ``>`` and therefore no complete
            function header.
        """
        # Extract function name; without ">" this is not a valid xml function call
        function_name, header_close, parameters = function_call_str.partition(">")
        if not header_close:
            return None

        param_config = find_tool_properties(self.tools, function_name)
        param_dict = {}
        for match_text in self.tool_call_parameter_regex.findall(parameters):
            param_name, name_close, param_value = match_text.partition(">")
            if not name_close:
                # The parameter regex also matches a header that was cut off
                # before its ">" (e.g. truncated output). Skip just that
                # fragment instead of raising and losing every parsed call.
                logger.debug(
                    "Skipping malformed parameter '%s' without closing '>' "
                    "in tool '%s'.",
                    match_text,
                    function_name,
                )
                continue

            # Remove one leading and one trailing newline around the value
            param_value = param_value.removeprefix("\n").removesuffix("\n")

            param_dict[param_name] = self._convert_param_value(
                param_value, param_name, param_config, function_name
            )

        return ToolCall(
            type="function",
            function=FunctionCall(
                name=function_name, arguments=json.dumps(param_dict, ensure_ascii=False)
            ),
        )

    def _get_function_calls(self, model_output: str) -> list[str]:
        # Find all tool calls
        matched_ranges = self.tool_call_regex.findall(model_output)
        raw_tool_calls = [
            match[0] if match[0] else match[1] for match in matched_ranges
        ]

        # Back-off strategy if no tool_call tags found
        if len(raw_tool_calls) == 0:
            raw_tool_calls = [model_output]

        raw_function_calls = []
        for tool_call in raw_tool_calls:
268
            raw_function_calls.extend(self.tool_call_function_regex.findall(tool_call))
269
270
271
272
273
274
275
276
277
278
279
280
281

        function_calls = [
            match[0] if match[0] else match[1] for match in raw_function_calls
        ]
        return function_calls

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract all tool calls from a complete (non-streamed) model output.

        Args:
            model_output: Full text produced by the model.
            request: The originating chat completion request (not read here).

        Returns:
            An ``ExtractedToolCallInformation`` holding the parsed tool calls
            and any text that precedes the first tool call. When no function
            call is found, or any exception occurs while parsing, the whole
            output is returned as plain content with ``tools_called=False``.
        """

        def _as_plain_content() -> ExtractedToolCallInformation:
            # Fallback result: nothing parsed, hand the text back untouched.
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        # Quick check to avoid unnecessary processing
        if self.tool_call_prefix not in model_output:
            return _as_plain_content()

        try:
            function_calls = self._get_function_calls(model_output)
            if not function_calls:
                return _as_plain_content()

            parsed_calls = [
                self._parse_xml_function_call(raw_call) for raw_call in function_calls
            ]

            # Populate prev_tool_call_arr (in place) for the serving layer to
            # set finish_reason.
            self.prev_tool_call_arr.clear()
            self.prev_tool_call_arr.extend(
                {
                    "name": parsed.function.name,
                    "arguments": parsed.function.arguments,
                }
                for parsed in parsed_calls
                if parsed
            )

            # Content is whatever precedes the first <tool_call> tag, or the
            # first <function= prefix when no such tag exists.
            cut = model_output.find(self.tool_call_start_token)
            if cut < 0:
                cut = model_output.find(self.tool_call_prefix)
            content = model_output[:cut]

            valid_tool_calls = [
                parsed for parsed in parsed_calls if parsed is not None
            ]
            return ExtractedToolCallInformation(
                tools_called=bool(valid_tool_calls),
                tool_calls=valid_tool_calls,
                content=content or None,
            )

        except Exception:
            logger.exception("Error in extracting tool call from response.")
            return _as_plain_content()
325
326
327
328
329
330
331
332
333
334

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Turn one streamed chunk into at most one ``DeltaMessage``.

        The method re-scans ``current_text`` on every call and uses the
        instance's streaming state to decide what still has to be emitted for
        the tool call at ``current_tool_index``. Per tool call it emits, in
        order and one message per call: the header (id and function name),
        the opening ``{``, JSON fragments for every parameter that is complete
        in the text, and finally the closing ``}``.

        Args:
            previous_text: Text before this chunk; an empty value resets the
                streaming state.
            current_text: All text generated so far, including this chunk.
            delta_text: Text added by this chunk.
            previous_token_ids: Not read by this method.
            current_token_ids: Not read by this method.
            delta_token_ids: Token IDs of this chunk; checked for the
                tool-call start/end token IDs.
            request: Originating request, stored on the first chunk.

        Returns:
            A ``DeltaMessage`` carrying plain content or a tool-call fragment,
            or ``None`` when this chunk produces nothing to send.
        """
        # Store request for type conversion
        if not previous_text:
            self._reset_streaming_state()
            self.streaming_request = request

        # If no delta text, return None unless it's an EOS token after tools
        if not delta_text:
            # Check if this is an EOS token after all tool calls are complete
            # Check for tool calls in text even if is_tool_call_started
            # is False (might have been reset after processing all tools)
            if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids:
                # Count complete tool calls
                complete_calls = len(
                    self.tool_call_complete_regex.findall(current_text)
                )

                # If we have completed tool calls and populated
                # prev_tool_call_arr
                if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
                    # Check if all tool calls are closed
                    open_calls = current_text.count(
                        self.tool_call_start_token
                    ) - current_text.count(self.tool_call_end_token)
                    if open_calls == 0:
                        # Return empty delta for finish_reason processing
                        return DeltaMessage(content="")
                elif not self.is_tool_call_started and current_text:
                    # This is a regular content response that's now complete
                    return DeltaMessage(content="")
            return None

        # Update accumulated text
        self.accumulated_text = current_text

        # Check if we need to advance to next tool
        if self.json_closed and not self.in_function:
            # Check if this tool call has ended
            tool_ends = current_text.count(self.tool_call_end_token)
            if tool_ends > self.current_tool_index:
                # This tool has ended, advance to next
                self.current_tool_index += 1
                self.header_sent = False
                self.param_count = 0
                self.json_started = False
                self.json_closed = False
                self.accumulated_params = {}

                # Check if there are more tool calls
                tool_starts = current_text.count(self.tool_call_start_token)
                if self.current_tool_index >= tool_starts:
                    # No more tool calls
                    self.is_tool_call_started = False
                # Continue processing next tool
                return None

        # Handle normal content before tool calls
        if not self.is_tool_call_started:
            # Check if tool call is starting
            if (
                self.tool_call_start_token_id in delta_token_ids
                or self.tool_call_start_token in delta_text
            ):
                self.is_tool_call_started = True
                # Return any content before the tool call
                if self.tool_call_start_token in delta_text:
                    content_before = delta_text[
                        : delta_text.index(self.tool_call_start_token)
                    ]
                    if content_before:
                        return DeltaMessage(content=content_before)
                return None
            else:
                # Check if we're between tool calls - skip whitespace
                if (
                    current_text.rstrip().endswith(self.tool_call_end_token)
                    and delta_text.strip() == ""
                ):
                    # We just ended a tool call, skip whitespace
                    return None
                # Normal content, no tool call
                return DeltaMessage(content=delta_text)

        # Check if we're between tool calls (waiting for next one)
        # Count tool calls we've seen vs processed
        tool_starts_count = current_text.count(self.tool_call_start_token)
        if self.current_tool_index >= tool_starts_count:
            # We're past all tool calls, shouldn't be here
            return None

        # We're in a tool call, find the current tool call portion
        # Need to find the correct tool call based on current_tool_index
        tool_start_positions: list[int] = []
        idx = 0
        while True:
            idx = current_text.find(self.tool_call_start_token, idx)
            if idx == -1:
                break
            tool_start_positions.append(idx)
            idx += len(self.tool_call_start_token)

        if self.current_tool_index >= len(tool_start_positions):
            # No more tool calls to process yet
            return None

        tool_start_idx = tool_start_positions[self.current_tool_index]
        # Find where this tool call ends (or current position if not ended yet)
        tool_end_idx = current_text.find(self.tool_call_end_token, tool_start_idx)
        if tool_end_idx == -1:
            tool_text = current_text[tool_start_idx:]
        else:
            tool_text = current_text[
                tool_start_idx : tool_end_idx + len(self.tool_call_end_token)
            ]

        # Looking for function header
        if not self.header_sent:
            if self.tool_call_prefix in tool_text:
                func_start = tool_text.find(self.tool_call_prefix) + len(
                    self.tool_call_prefix
                )
                func_end = tool_text.find(">", func_start)

                if func_end != -1:
                    # Found complete function name
                    self.current_function_name = tool_text[func_start:func_end]
                    self.current_tool_id = self._generate_tool_call_id()
                    self.header_sent = True
                    self.in_function = True

                    # Always append — each tool call is a separate
                    # invocation even if the function name is the same
                    # (e.g. two consecutive "read" calls).
                    self.prev_tool_call_arr.append(
                        {
                            "name": self.current_function_name,
                            "arguments": "{}",
                        }
                    )

                    # Initialize streamed args tracking for this tool.
                    # The serving layer reads streamed_args_for_tool to
                    # compute remaining arguments at stream end. Without
                    # this, IndexError occurs when the serving layer
                    # accesses streamed_args_for_tool[index].
                    self.streamed_args_for_tool.append("")

                    # Send header with function info
                    return DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_index,
                                id=self.current_tool_id,
                                function=DeltaFunctionCall(
                                    name=self.current_function_name, arguments=""
                                ),
                                type="function",
                            )
                        ]
                    )
            return None

        # We've sent header, now handle function body
        if self.in_function:
            # Always send opening brace first, regardless of whether
            # parameter_prefix is in the current delta. With speculative
            # decoding, a single delta may contain both the opening brace
            # and parameter data; skipping "{" here would desync
            # json_started from what was actually streamed.
            if not self.json_started:
                self.json_started = True
                self.streamed_args_for_tool[self.current_tool_index] += "{"
                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_index,
                            function=DeltaFunctionCall(arguments="{"),
                        )
                    ]
                )

            # Find all parameter start positions in current tool_text
            param_starts = []
            search_idx = 0
            while True:
                search_idx = tool_text.find(self.parameter_prefix, search_idx)
                if search_idx == -1:
                    break
                param_starts.append(search_idx)
                search_idx += len(self.parameter_prefix)

            # Process ALL complete params in a loop (spec decode fix).
            # With speculative decoding a single delta can deliver
            # multiple complete parameters at once. The old single-pass
            # code would process one and ``return None`` if the next was
            # incomplete — skipping any already-complete params that
            # preceded it. Using a loop with ``break`` instead ensures
            # we emit every complete parameter before yielding control.
            json_fragments = []
            while not self.in_param and self.param_count < len(param_starts):
                param_idx = param_starts[self.param_count]
                param_start = param_idx + len(self.parameter_prefix)
                remaining = tool_text[param_start:]

                if ">" not in remaining:
                    break

                name_end = remaining.find(">")
                current_param_name = remaining[:name_end]

                value_start = param_start + name_end + 1
                value_text = tool_text[value_start:]
                if value_text.startswith("\n"):
                    value_text = value_text[1:]

                param_end_idx = value_text.find(self.parameter_end_token)
                if param_end_idx == -1:
                    next_param_idx = value_text.find(self.parameter_prefix)
                    func_end_idx = value_text.find(self.function_end_token)

                    if next_param_idx != -1 and (
                        func_end_idx == -1 or next_param_idx < func_end_idx
                    ):
                        param_end_idx = next_param_idx
                    elif func_end_idx != -1:
                        param_end_idx = func_end_idx
                    else:
                        # Fallback for malformed XML where </function>
                        # is missing. Use </tool_call> as a delimiter
                        # if present in the value so we don't include
                        # the closing tag as part of the param value.
                        tool_end_in_value = value_text.find(self.tool_call_end_token)
                        if tool_end_in_value != -1:
                            param_end_idx = tool_end_in_value
                        else:
                            # Parameter incomplete — break so we still
                            # emit any fragments accumulated by earlier
                            # loop iterations.
                            break

                if param_end_idx == -1:
                    break

                param_value = value_text[:param_end_idx]
                if param_value.endswith("\n"):
                    param_value = param_value[:-1]

                self.current_param_name = current_param_name
                self.accumulated_params[current_param_name] = param_value

                param_config = find_tool_properties(
                    self.tools, self.current_function_name or ""
                )

                converted_value = self._convert_param_value(
                    param_value,
                    current_param_name,
                    param_config,
                    self.current_function_name or "",
                )

                serialized_value = json.dumps(converted_value, ensure_ascii=False)
                # Encode the key with json.dumps too: a quote or backslash in
                # a parameter name would otherwise produce invalid JSON. For
                # ordinary names this yields exactly '"name"'.
                serialized_name = json.dumps(current_param_name, ensure_ascii=False)

                if self.param_count == 0:
                    json_fragment = f"{serialized_name}: {serialized_value}"
                else:
                    json_fragment = f", {serialized_name}: {serialized_value}"

                self.param_count += 1
                json_fragments.append(json_fragment)

            if json_fragments:
                combined = "".join(json_fragments)

                if self.current_tool_index < len(self.streamed_args_for_tool):
                    self.streamed_args_for_tool[self.current_tool_index] += combined
                else:
                    logger.warning(
                        "streamed_args_for_tool out of sync: index=%d len=%d",
                        self.current_tool_index,
                        len(self.streamed_args_for_tool),
                    )

                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_index,
                            function=DeltaFunctionCall(arguments=combined),
                        )
                    ]
                )

            # Check for function end AFTER processing parameters.
            # This ordering is critical: with speculative decoding a
            # burst can deliver the final parameter value together with
            # </function>. If the close check ran first it would emit
            # "}" and set in_function=False before the parameter loop
            # ever ran, causing the parameter to be silently dropped.
            if not self.json_closed and self.function_end_token in tool_text:
                self.json_closed = True

                func_start = tool_text.find(self.tool_call_prefix) + len(
                    self.tool_call_prefix
                )
                func_content_end = tool_text.find(self.function_end_token, func_start)
                if func_content_end != -1:
                    func_content = tool_text[func_start:func_content_end]
                    # Re-parse the finished call so prev_tool_call_arr holds
                    # its final arguments; failures here are only logged.
                    try:
                        parsed_tool = self._parse_xml_function_call(
                            func_content,
                        )
                        if parsed_tool and self.current_tool_index < len(
                            self.prev_tool_call_arr
                        ):
                            self.prev_tool_call_arr[self.current_tool_index][
                                "arguments"
                            ] = parsed_tool.function.arguments
                    except Exception:
                        logger.debug(
                            "Failed to parse tool call during streaming: %s",
                            tool_text,
                            exc_info=True,
                        )

                if self.current_tool_index < len(self.streamed_args_for_tool):
                    self.streamed_args_for_tool[self.current_tool_index] += "}"
                else:
                    logger.warning(
                        "streamed_args_for_tool out of sync: index=%d len=%d",
                        self.current_tool_index,
                        len(self.streamed_args_for_tool),
                    )

                result = DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_index,
                            function=DeltaFunctionCall(arguments="}"),
                        )
                    ]
                )

                self.in_function = False
                self.json_closed = True
                self.accumulated_params = {}

                return result

        return None