# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
luopl's avatar
luopl committed
5
import uuid
csy0225's avatar
csy0225 committed
6
7
from collections.abc import Sequence
from typing import Any
luopl's avatar
luopl committed
8
# from xml.parsers.expat import ParserCreate
csy0225's avatar
csy0225 committed
9
10
11

import regex as re

luopl's avatar
luopl committed
12
# from vllm.entrypoints.chat_utils import make_tool_call_id
csy0225's avatar
csy0225 committed
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from vllm.entrypoints.openai.chat_completion.protocol import (
    ChatCompletionRequest,
    ChatCompletionToolsParam,
)
from vllm.entrypoints.openai.engine.protocol import (
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
    ToolParser,
)

logger = init_logger(__name__)


luopl's avatar
luopl committed
34
35
36
class Step3p5ToolParser(ToolParser):
    def __init__(self, tokenizer: TokenizerLike):
        """Set up sentinel tokens, regex patterns and parser state.

        Args:
            tokenizer: Tokenizer forwarded to the base ``ToolParser``.

        Raises:
            ValueError: If ``self.model_tokenizer`` is falsy after base init.
            RuntimeError: If the tool call start or end token is absent from
                ``self.vocab``.
        """
        super().__init__(tokenizer)

        # Tool call bookkeeping.
        self.current_tool_name_sent: bool = False
        self.prev_tool_call_arr: list[dict] = []
        # Override base class type - we use string IDs for tool calls
        self.current_tool_id: str | None = None  # type: ignore
        self.streamed_args_for_tool: list[str] = []

        # Sentinel tokens for streaming mode
        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"
        self.tool_call_prefix: str = "<function="
        self.function_end_token: str = "</function>"
        self.parameter_prefix: str = "<parameter="
        self.parameter_end_token: str = "</parameter>"
        self.is_tool_call_started: bool = False
        self.failed_count: int = 0

        # Enhanced streaming state - reset for each new message
        self._reset_streaming_state()

        # Regex patterns; DOTALL lets a single tag body span several lines.
        self.tool_call_complete_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>", flags=re.DOTALL
        )
        self.tool_call_function_regex = re.compile(
            r"<function(?:=|\s+)?(.*?)</function>", flags=re.DOTALL
        )
        self.tool_call_parameter_regex = re.compile(
            r"<parameter=(.*?)</parameter>", flags=re.DOTALL
        )

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        # Both sentinel tokens must be present in the vocabulary.
        start_token_id = self.vocab.get(self.tool_call_start_token)
        end_token_id = self.vocab.get(self.tool_call_end_token)
        self.tool_call_start_token_id = start_token_id
        self.tool_call_end_token_id = end_token_id

        if start_token_id is None or end_token_id is None:
            raise RuntimeError(
                "Step3p5 RL Tool parser could not locate tool call start/end "
                "tokens in the tokenizer!"
            )

        # Get EOS token ID for EOS detection
        self.eos_token_id = getattr(self.model_tokenizer, "eos_token_id", None)

        logger.info(
            "vLLM Successfully import tool parser %s !", self.__class__.__name__
        )

    def _generate_tool_call_id(self) -> str:
        """Generate a unique tool call ID."""
        return f"call_{uuid.uuid4().hex[:24]}"

    def parser_should_check_for_unstreamed_tool_arg_tokens(self) -> bool:
csy0225's avatar
csy0225 committed
95
        """
luopl's avatar
luopl committed
96
97
98
        Skip the remaining_call calculation in serving
        """
        return False
csy0225's avatar
csy0225 committed
99

luopl's avatar
luopl committed
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
    def _reset_streaming_state(self):
        """Reset all streaming state for a new request."""
        self._processed_length: int = 0  # Position of last processed character
        self._tool_call_index: int = 0  # Number of tool calls processed so far
        self.streaming_request = None  # Current request being processed

    def _get_arguments_config(
        self, func_name: str, tools: list[ChatCompletionToolsParam] | None
    ) -> dict:
        """Extract argument configuration for a function."""
        if tools is None:
            return {}
        for config in tools:
            if not hasattr(config, "type") or not (
                hasattr(config, "function") and hasattr(config.function, "name")
            ):
                continue
            if config.type == "function" and config.function.name == func_name:
                if not hasattr(config.function, "parameters"):
                    return {}
                params = config.function.parameters
                if isinstance(params, dict) and "properties" in params:
                    return params["properties"]
                elif isinstance(params, dict):
                    return params
                else:
                    return {}
        logger.warning("Tool '%s' is not defined in the tools list.", func_name)
        return {}

    def _convert_param_value(
        self, param_value: str, param_name: str, param_config: dict, func_name: str
    ) -> Any:
        """Convert parameter value based on its type in the schema."""
        # Handle null value for any type
        if param_value.lower() == "null":
            return None
csy0225's avatar
csy0225 committed
137

luopl's avatar
luopl committed
138
139
140
141
142
143
144
145
146
147
        if param_name not in param_config:
            if param_config != {}:
                logger.warning(
                    "Parsed parameter '%s' is not defined in the tool "
                    "parameters for tool '%s', directly returning the "
                    "string value.",
                    param_name,
                    func_name,
                )
            return param_value
csy0225's avatar
csy0225 committed
148

luopl's avatar
luopl committed
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
        if (
            isinstance(param_config[param_name], dict)
            and "type" in param_config[param_name]
        ):
            param_type = str(param_config[param_name]["type"]).strip().lower()
        else:
            param_type = "string"
        if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
            return param_value
        elif (
            param_type.startswith("int")
            or param_type.startswith("uint")
            or param_type.startswith("long")
            or param_type.startswith("short")
            or param_type.startswith("unsigned")
        ):
csy0225's avatar
csy0225 committed
165
            try:
luopl's avatar
luopl committed
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
                return int(param_value)
            except (ValueError, TypeError):
                try:
                    float_value = float(param_value)
                    if float_value.is_integer():
                        return int(float_value)
                except (ValueError, TypeError):
                    pass
                try:
                    literal_value = ast.literal_eval(param_value)
                    if isinstance(literal_value, bool):
                        return int(literal_value)
                    if isinstance(literal_value, (int, float)):
                        return (
                            int(literal_value)
                            if float(literal_value).is_integer()
                            else literal_value
csy0225's avatar
csy0225 committed
183
                        )
luopl's avatar
luopl committed
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
                except (ValueError, SyntaxError, TypeError):
                    pass
                logger.warning(
                    "Parsed value '%s' of parameter '%s' is not an integer "
                    "in tool '%s', returning raw string.",
                    param_value,
                    param_name,
                    func_name,
                )
                return param_value
        elif param_type.startswith("num") or param_type.startswith("float"):
            try:
                float_param_value = float(param_value)
                return (
                    float_param_value
                    if float_param_value - int(float_param_value) != 0
                    else int(float_param_value)
                )
            except (ValueError, TypeError):
                try:
                    literal_value = ast.literal_eval(param_value)
                    if isinstance(literal_value, (int, float)):
                        return (
                            float(literal_value)
                            if float(literal_value) - int(float(literal_value)) != 0
                            else int(float(literal_value))
                        )
                except (ValueError, SyntaxError, TypeError):
                    pass
                logger.warning(
                    "Parsed value '%s' of parameter '%s' is not a float "
                    "in tool '%s', returning raw string.",
                    param_value,
                    param_name,
                    func_name,
                )
                return param_value
        elif param_type in ["boolean", "bool", "binary"]:
            normalized_value = param_value.strip().lower()
            if normalized_value in ["true", "false"]:
                return normalized_value == "true"
            if normalized_value in ["1", "0"]:
                return normalized_value == "1"
            try:
                literal_value = ast.literal_eval(param_value)
                if isinstance(literal_value, bool):
                    return literal_value
            except (ValueError, SyntaxError, TypeError):
                pass
            logger.warning(
                "Parsed value '%s' of parameter '%s' is not a boolean "
                "in tool '%s', returning raw string.",
                param_value,
                param_name,
                func_name,
            )
            return param_value
        else:
            if (
                param_type in ["object", "array", "arr"]
                or param_type.startswith("dict")
                or param_type.startswith("list")
            ):
                try:
                    param_value = json.loads(param_value)
                    return param_value
                except (json.JSONDecodeError, TypeError, ValueError):
                    try:
                        literal_value = ast.literal_eval(param_value)
                        if isinstance(literal_value, (list, dict)):
                            return literal_value
                        if isinstance(literal_value, (tuple, set)):
                            return list(literal_value)
                    except (ValueError, SyntaxError, TypeError):
                        pass
                    logger.warning(
                        "Parsed value '%s' of parameter '%s' cannot be parsed "
                        "as JSON in tool '%s', returning raw string.",
                        param_value,
                        param_name,
                        func_name,
csy0225's avatar
csy0225 committed
265
                    )
luopl's avatar
luopl committed
266
267
268
269
270
                    return param_value
            try:
                literal_value = ast.literal_eval(param_value)  # safer
                if isinstance(literal_value, (tuple, set)):
                    return list(literal_value)
csy0225's avatar
csy0225 committed
271
                if (
luopl's avatar
luopl committed
272
273
                    isinstance(literal_value, (list, dict, str, int, float, bool))
                    or literal_value is None
csy0225's avatar
csy0225 committed
274
                ):
luopl's avatar
luopl committed
275
276
277
278
279
280
281
282
283
                    return literal_value
            except (ValueError, SyntaxError, TypeError):
                pass
            logger.warning(
                "Parsed value '%s' of parameter '%s' cannot be converted via "
                "Python `ast.literal_eval()` in tool '%s', returning raw string.",
                param_value,
                param_name,
                func_name,
csy0225's avatar
csy0225 committed
284
            )
luopl's avatar
luopl committed
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
            return param_value

    def _parse_parameters_fallback(
        self,
        parameters: str,
        allowed_param_names: set[str] | None = None,
    ) -> list[tuple[str, str]]:
        """Fallback parser for malformed parameter tags."""
        param_pairs: list[tuple[str, str]] = []
        pos = 0
        while True:
            start = parameters.find(self.parameter_prefix, pos)
            if start == -1:
                break
            name_start = start + len(self.parameter_prefix)
            name_end = parameters.find(">", name_start)
            if name_end == -1:
                newline_idx = parameters.find("\n", name_start)
                end_tag = parameters.find(self.parameter_end_token, name_start)
                next_param = parameters.find(self.parameter_prefix, name_start)
                candidates = [
                    idx for idx in [newline_idx, end_tag, next_param] if idx != -1
                ]
                if not candidates:
                    break
                name_end = min(candidates)
                value_start = name_end
            else:
                value_start = name_end + 1
            param_name = parameters[name_start:name_end].strip()
            next_param = parameters.find(self.parameter_prefix, value_start)
            end_tag = parameters.find(self.parameter_end_token, value_start)
            if end_tag == -1 or (next_param != -1 and next_param < end_tag):
                end = next_param if next_param != -1 else len(parameters)
                pos = end
            else:
                end = end_tag
                pos = end + len(self.parameter_end_token)
            param_value = parameters[value_start:end]
            if allowed_param_names is None or param_name in allowed_param_names:
                param_pairs.append((param_name, param_value))
        return param_pairs

    def _is_valid_json_arguments(self, arguments: str) -> bool:
        """Check if arguments can be loaded as JSON."""
        try:
            json.loads(arguments)
        except Exception:
            return False
        return True

    def _parse_xml_function_call(
        self, function_call_str: str, tools: list[ChatCompletionToolsParam] | None
    ) -> ToolCall | None:
        """Build a ``ToolCall`` from the text captured inside a function tag.

        Args:
            function_call_str: Text following ``<function`` (the function
                name, a ``>``, then the parameter tags).
            tools: Tool definitions used to look up parameter types.

        Returns:
            The parsed ``ToolCall``, or ``None`` when the tag header has no
            ``>``, the function name is empty, or the arguments cannot be
            serialized to JSON.
        """
        # Split the header (function name) from the parameter section. A
        # missing ">" is treated like any other malformed header instead of
        # raising ValueError.
        header, bracket, parameters = function_call_str.partition(">")
        if not bracket:
            logger.warning("Malformed function tag in tool call: missing '>'.")
            return None

        # check empty function name
        function_name = header.strip()
        if function_name.startswith("="):
            function_name = function_name.lstrip("=").strip()
        if not function_name or function_name.strip("'\"") == "":
            logger.warning("Empty function name in tool call.")
            return None
        # Drop one pair of matching quotes around the name.
        if function_name[0] in "\"'" and function_name[-1] == function_name[0]:
            function_name = function_name[1:-1].strip()
            if not function_name:
                logger.warning("Empty function name in tool call.")
                return None

        param_config = self._get_arguments_config(function_name, tools)
        match_texts = self.tool_call_parameter_regex.findall(parameters)

        # The regex result is unreliable when a match swallowed another
        # opening tag or has no ">" after the name; use the manual scanner.
        if match_texts:
            use_fallback = any(
                self.parameter_prefix in match_text or ">" not in match_text
                for match_text in match_texts
            )
        else:
            use_fallback = self.parameter_prefix in parameters

        if use_fallback:
            allowed_param_names = (
                set(param_config.keys())
                if isinstance(param_config, dict) and param_config
                else None
            )
            param_pairs = self._parse_parameters_fallback(
                parameters, allowed_param_names
            )
        else:
            param_pairs = []
            for match_text in match_texts:
                raw_name, _, raw_value = match_text.partition(">")
                # Strip the name as the fallback scanner does, so both paths
                # produce the same keys.
                param_pairs.append((raw_name.strip(), raw_value))

        param_dict = {}
        for param_name, param_value in param_pairs:
            # Remove prefix and trailing \n
            if param_value.startswith("\n"):
                param_value = param_value[1:]
            if param_value.endswith("\n"):
                param_value = param_value[:-1]

            param_dict[param_name] = self._convert_param_value(
                param_value, param_name, param_config, function_name
            )

        try:
            arguments = json.dumps(param_dict, ensure_ascii=False)
        except Exception as e:
            logger.warning("Error in converting parameter value: %s", e)
            return None
        return ToolCall(
            type="function",
            function=FunctionCall(name=function_name, arguments=arguments),
        )
csy0225's avatar
csy0225 committed
405

luopl's avatar
luopl committed
406
407
408
    def _get_function_calls(self, model_output: str) -> list[str]:
        # Find all tool calls
        raw_tool_calls = self.tool_call_complete_regex.findall(model_output)
csy0225's avatar
csy0225 committed
409

luopl's avatar
luopl committed
410
411
412
        # if no closed tool_call tags found, return empty list
        if len(raw_tool_calls) == 0:
            return []
csy0225's avatar
csy0225 committed
413

luopl's avatar
luopl committed
414
415
416
417
        raw_function_calls = []
        for tool_call in raw_tool_calls:
            function_matches = self.tool_call_function_regex.findall(tool_call)
            raw_function_calls.extend(function_matches)
csy0225's avatar
csy0225 committed
418

luopl's avatar
luopl committed
419
        return raw_function_calls
csy0225's avatar
csy0225 committed
420

luopl's avatar
luopl committed
421
422
    def _check_format(self, model_output: str) -> bool:
        """Check if model output contains properly formatted tool call.
csy0225's avatar
csy0225 committed
423

luopl's avatar
luopl committed
424
425
426
427
        Requirements:
        1. Must have closed tool_call tags (<tool_call>...</tool_call>)
        2. Must have closed function tags (<function=...</function>)
        3. If parameter tags exist, they must be closed and correct
csy0225's avatar
csy0225 committed
428

luopl's avatar
luopl committed
429
430
431
432
433
434
        Returns True if the format is valid, False otherwise.
        """
        # Check 1: Must have closed tool_call tags
        tool_call_matches = self.tool_call_complete_regex.findall(model_output)
        if len(tool_call_matches) == 0:
            return False
csy0225's avatar
csy0225 committed
435

luopl's avatar
luopl committed
436
437
438
439
440
441
442
443
444
445
446
447
        # Check 2: Must have closed function tags within tool_call
        has_valid_function = False
        for tool_call_content in tool_call_matches:
            function_matches = self.tool_call_function_regex.findall(tool_call_content)
            if len(function_matches) > 0:
                has_valid_function = True
            # Check if there's an unclosed function tag
            if (
                self.tool_call_prefix in tool_call_content
                and self.function_end_token not in tool_call_content
            ):
                return False
csy0225's avatar
csy0225 committed
448

luopl's avatar
luopl committed
449
450
        if not has_valid_function:
            return False
csy0225's avatar
csy0225 committed
451

luopl's avatar
luopl committed
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
        # Check 3: If parameter tags exist, they must be closed and correct
        for tool_call_content in tool_call_matches:
            # Count opening and closing parameter tags
            param_open_count = tool_call_content.count(self.parameter_prefix)
            param_close_count = tool_call_content.count(self.parameter_end_token)

            # If there are parameter tags, they must be balanced
            if param_open_count > 0:
                if param_open_count != param_close_count:
                    return False
                # Check if all parameter tags are properly closed using regex
                param_matches = self.tool_call_parameter_regex.findall(
                    tool_call_content
                )
                if len(param_matches) != param_open_count:
                    return False
csy0225's avatar
csy0225 committed
468

luopl's avatar
luopl committed
469
        return True
csy0225's avatar
csy0225 committed
470

luopl's avatar
luopl committed
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
    def _wrap_missing_tool_call_tags(self, model_output: str) -> str:
        """Wrap bare <function=...></function> blocks with <tool_call> tags."""
        if (
            self.tool_call_prefix not in model_output
            or self.function_end_token not in model_output
        ):
            return model_output

        def _wrap_bare_functions(text: str) -> str:
            pos = 0
            wrapped_parts: list[str] = []
            while True:
                func_idx = text.find(self.tool_call_prefix, pos)
                if func_idx == -1:
                    wrapped_parts.append(text[pos:])
                    break
                end_idx = text.find(self.function_end_token, func_idx)
                if end_idx == -1:
                    wrapped_parts.append(text[pos:])
                    break
                end_idx += len(self.function_end_token)
                wrapped_parts.append(text[pos:func_idx])
                wrapped_parts.append(self.tool_call_start_token)
                wrapped_parts.append(text[func_idx:end_idx])
                wrapped_parts.append(self.tool_call_end_token)

                ws_idx = end_idx
                while ws_idx < len(text) and text[ws_idx].isspace():
                    ws_idx += 1
                if text.startswith(self.tool_call_end_token, ws_idx):
                    if ws_idx > end_idx:
                        wrapped_parts.append(text[end_idx:ws_idx])
                    pos = ws_idx + len(self.tool_call_end_token)
                else:
                    pos = end_idx
            return "".join(wrapped_parts)

        tool_call_ranges = [
            match.span()
            for match in self.tool_call_complete_regex.finditer(model_output)
        ]
        if not tool_call_ranges:
            return _wrap_bare_functions(model_output)

        wrapped_parts: list[str] = []
        pos = 0
        for start, end in tool_call_ranges:
            if start < pos:
csy0225's avatar
csy0225 committed
519
                continue
luopl's avatar
luopl committed
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
            wrapped_parts.append(_wrap_bare_functions(model_output[pos:start]))
            wrapped_parts.append(model_output[start:end])
            pos = end
        wrapped_parts.append(_wrap_bare_functions(model_output[pos:]))
        return "".join(wrapped_parts)

    def _normalize_prev_arguments(self, args_value: Any) -> Any:
        if isinstance(args_value, str):
            try:
                return json.loads(args_value)
            except (TypeError, ValueError, json.JSONDecodeError):
                return args_value
        return args_value

    def _update_prev_tool_call_state(self, tool_calls: list[ToolCall]) -> None:
        self.prev_tool_call_arr.clear()
        self.streamed_args_for_tool.clear()
        for tool_call in tool_calls:
            if not tool_call or not tool_call.function:
                continue
            args_value = tool_call.function.arguments
            if isinstance(args_value, str):
                args_json = args_value
            elif args_value is None:
                args_json = ""
            else:
                try:
                    args_json = json.dumps(args_value, ensure_ascii=False)
                except (TypeError, ValueError):
                    args_json = str(args_value)

            prev_args = self._normalize_prev_arguments(args_json)
            self.prev_tool_call_arr.append(
                {
                    "name": tool_call.function.name,
                    "arguments": prev_args,
                }
            )
            try:
                expected_args_json = json.dumps(prev_args, ensure_ascii=False)
            except (TypeError, ValueError):
                expected_args_json = args_json
csy0225's avatar
csy0225 committed
562

luopl's avatar
luopl committed
563
564
565
566
567
            # Serving may subtract the latest delta length from
            # streamed_args_for_tool to detect unstreamed suffixes. Since this
            # parser emits full arguments at once, store expected+actual so
            # the subtraction yields expected_args_json and no resend occurs.
            self.streamed_args_for_tool.append(expected_args_json + args_json)
csy0225's avatar
csy0225 committed
568

luopl's avatar
luopl committed
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract all complete tool calls from a full model output.

        The output is scanned left to right. Text outside tool calls, tool
        calls that fail to parse, and an incomplete trailing tool call are all
        kept as content. Whitespace-only text between two tool calls is
        dropped. ``prev_tool_call_arr`` and ``streamed_args_for_tool`` are
        refreshed from the parsed tool calls.

        Args:
            model_output: Complete text produced by the model.
            request: The request; passed on to ``extract_tool_calls_basic``.

        Returns:
            The parsed tool calls plus remaining content (``None`` when
            empty). If an unexpected error occurs, no tool calls are reported
            and the original, unmodified *model_output* is returned as
            content.
        """
        # Keep the untouched text: the error fallback must not return
        # <tool_call> tags that were injected by the wrapping step.
        raw_model_output = model_output
        try:
            try:
                # Fallback: handle outputs without <tool_call> wrapper.
                model_output = self._wrap_missing_tool_call_tags(model_output)
            except Exception:
                # Best effort: continue with the output as the model wrote it.
                logger.warning(
                    "Failed to wrap bare function blocks; parsing output as-is.",
                    exc_info=True,
                )

            # Use streaming-like approach: process position by position
            valid_tool_calls = []
            content_parts = []
            processed_length = 0

            while processed_length < len(model_output):
                # Find next tool call start
                tool_start_idx = self._find_tool_call_start(
                    model_output, processed_length
                )

                # Case 1: No more tool calls - add remaining as content
                if tool_start_idx == -1:
                    remaining = model_output[processed_length:]
                    if remaining:
                        content_parts.append(remaining)
                    break

                # Case 2: Content before tool call
                if tool_start_idx > processed_length:
                    content_before = model_output[processed_length:tool_start_idx]
                    # Skip whitespace-only content that directly follows a
                    # previous tool call's end tag.
                    between_tool_calls = (
                        processed_length > 0
                        and not content_before.strip()
                        and model_output[:processed_length]
                        .rstrip()
                        .endswith(self.tool_call_end_token)
                    )
                    if not between_tool_calls:
                        content_parts.append(content_before)

                # Case 3: Try to find complete tool call
                tool_end_idx = self._find_first_complete_tool_call_end(
                    model_output, tool_start_idx
                )

                # If tool call is incomplete - add remaining as content and stop
                if tool_end_idx == -1:
                    remaining = model_output[tool_start_idx:]
                    if remaining:
                        content_parts.append(remaining)
                    break

                # Extract and try to parse the complete tool call
                tool_call_text = model_output[tool_start_idx:tool_end_idx]
                parsed_result = self.extract_tool_calls_basic(tool_call_text, request)

                if parsed_result.tools_called and parsed_result.tool_calls:
                    valid_tool_calls.extend(parsed_result.tool_calls)
                else:
                    # Parsing failed - treat this tool call as content
                    content_parts.append(tool_call_text)
                processed_length = tool_end_idx

            # Populate prev_tool_call_arr for serving layer to set finish_reason
            self._update_prev_tool_call_state(valid_tool_calls)

            # Combine content parts; an empty result is reported as None.
            content = "".join(content_parts)

            return ExtractedToolCallInformation(
                tools_called=(len(valid_tool_calls) > 0),
                tool_calls=valid_tool_calls,
                content=content or None,
            )

        except Exception:
            # Top-level boundary: never fail the request, but record the
            # traceback so the failure can be diagnosed.
            logger.warning(
                "Error in extracting tool call from response.", exc_info=True
            )
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=raw_model_output
            )

luopl's avatar
luopl committed
664
665
666
667
668
669
670
671
672
673
674
675
676
    def extract_tool_calls_basic(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse all tool calls contained in ``model_output`` in one pass.

        The text is first passed through ``_wrap_missing_tool_call_tags``;
        every later step, including the ``content`` returned on failure,
        works on that result rather than on the raw argument.

        Args:
            model_output: Text to parse.
            request: Chat completion request; ``request.tools`` is handed to
                ``_parse_xml_function_call`` for each extracted call.

        Returns:
            On success, ``tools_called=True`` with the parsed calls and the
            text preceding the first ``tool_call_start_token`` as content
            (``None`` when that text is empty). In every other case (format
            check fails, nothing parses, any call has invalid JSON arguments,
            or an exception is raised while parsing) ``tools_called=False``
            with an empty call list and the whole text as content.
        """
        model_output = self._wrap_missing_tool_call_tags(model_output)

        def _as_content() -> ExtractedToolCallInformation:
            # Single definition of the "no tool calls, everything is content"
            # result used by all failure paths below.
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        # Quick check to avoid unnecessary processing: give up only when the
        # format check fails AND no complete tool call can be matched.
        if not self._check_format(
            model_output
        ) and not self.tool_call_complete_regex.findall(model_output):
            return _as_content()

        try:
            function_calls = self._get_function_calls(model_output)
            if not function_calls:
                return _as_content()

            tool_calls: list[ToolCall] = []
            for function_call_str in function_calls:
                tool_call = self._parse_xml_function_call(
                    function_call_str, request.tools
                )
                if tool_call:
                    tool_calls.append(tool_call)
            if not tool_calls:
                return _as_content()

            # All-or-nothing: one call with unusable arguments rejects the
            # whole segment so that callers can surface it as plain text.
            if any(
                not call.function
                or call.function.arguments is None
                or not self._is_valid_json_arguments(call.function.arguments)
                for call in tool_calls
            ):
                logger.warning(
                    "Invalid JSON arguments in tool call, falling back to content."
                )
                return _as_content()

            # Populate prev_tool_call_arr for serving layer to set finish_reason
            self._update_prev_tool_call_state(tool_calls)

            # Extract content before tool calls. str.find returns -1 when the
            # start token is absent; slicing with -1 would silently return the
            # text minus its last character, so treat that case as "no
            # leading content" instead.
            content_index = model_output.find(self.tool_call_start_token)
            content = model_output[:content_index] if content_index != -1 else ""

            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content or None,
            )

        except Exception:
            logger.warning("Error in extracting tool call from response.")
            return _as_content()
csy0225's avatar
csy0225 committed
727

luopl's avatar
luopl committed
728
729
    def _find_first_complete_tool_call_end(self, text: str, start_pos: int = 0) -> int:
        """Find the end position of the first complete tool call.
csy0225's avatar
csy0225 committed
730
731

        Args:
luopl's avatar
luopl committed
732
733
            text: Text to search in
            start_pos: Position to start searching from
csy0225's avatar
csy0225 committed
734
735

        Returns:
luopl's avatar
luopl committed
736
            Position after the first </tool_call> tag, or -1 if incomplete
csy0225's avatar
csy0225 committed
737

luopl's avatar
luopl committed
738
739
740
741
742
743
744
745
746
747
748
749
750
751
        Example:
            "<tool_call>...</tool_call>..." returns position after </tool_call>
        """
        # Find tool call start
        start_idx = text.find(self.tool_call_start_token, start_pos)
        if start_idx == -1:
            return -1

        # Find matching end token
        end_idx = text.find(
            self.tool_call_end_token, start_idx + len(self.tool_call_start_token)
        )
        if end_idx == -1:
            return -1  # Incomplete tool call
csy0225's avatar
csy0225 committed
752

luopl's avatar
luopl committed
753
754
        # Return position after end token
        return end_idx + len(self.tool_call_end_token)
csy0225's avatar
csy0225 committed
755

luopl's avatar
luopl committed
756
757
    def _find_tool_call_start(self, text: str, start_pos: int = 0) -> int:
        """Find the start position of next tool call.
csy0225's avatar
csy0225 committed
758
759

        Args:
luopl's avatar
luopl committed
760
761
            text: Text to search in
            start_pos: Position to start searching from
csy0225's avatar
csy0225 committed
762
763

        Returns:
luopl's avatar
luopl committed
764
            Position of <tool_call> token, or -1 if not found
csy0225's avatar
csy0225 committed
765
        """
luopl's avatar
luopl committed
766
        return text.find(self.tool_call_start_token, start_pos)
csy0225's avatar
csy0225 committed
767

luopl's avatar
luopl committed
768
769
    def _extract_content_between_tool_calls_list(self, text: str) -> list[str]:
        """Extract content segments after each tool call.
csy0225's avatar
csy0225 committed
770

luopl's avatar
luopl committed
771
772
        For n tool calls, returns n segments where segment[i] is the content
        after tool_call[i] (before tool_call[i+1] or at the end).
csy0225's avatar
csy0225 committed
773

luopl's avatar
luopl committed
774
        Empty or whitespace-only segments are represented as empty string "".
csy0225's avatar
csy0225 committed
775
776

        Args:
luopl's avatar
luopl committed
777
            text: Text containing tool calls
csy0225's avatar
csy0225 committed
778
779

        Returns:
luopl's avatar
luopl committed
780
            List of content segments (one per tool call)
csy0225's avatar
csy0225 committed
781
        """
luopl's avatar
luopl committed
782
783
        content_segments = []
        pos = 0
csy0225's avatar
csy0225 committed
784

luopl's avatar
luopl committed
785
786
787
788
789
        while True:
            # Find end of current tool call
            end_pos = text.find(self.tool_call_end_token, pos)
            if end_pos == -1:
                break
csy0225's avatar
csy0225 committed
790

luopl's avatar
luopl committed
791
792
            # Move past the end token
            end_pos += len(self.tool_call_end_token)
csy0225's avatar
csy0225 committed
793

luopl's avatar
luopl committed
794
795
            # Find start of next tool call
            next_start = self._find_tool_call_start(text, end_pos)
csy0225's avatar
csy0225 committed
796

luopl's avatar
luopl committed
797
798
            # Extract content between current end and next start (or text end)
            content = text[end_pos:next_start] if next_start != -1 else text[end_pos:]
csy0225's avatar
csy0225 committed
799

luopl's avatar
luopl committed
800
801
            # Store content (empty string if whitespace-only)
            content_segments.append(content if content.strip() else "")
csy0225's avatar
csy0225 committed
802

luopl's avatar
luopl committed
803
804
805
            if next_start == -1:
                break
            pos = next_start
csy0225's avatar
csy0225 committed
806

luopl's avatar
luopl committed
807
        return content_segments
csy0225's avatar
csy0225 committed
808

luopl's avatar
luopl committed
809
810
811
812
    def _convert_tool_calls_to_deltas(
        self, tool_calls: list[ToolCall], starting_index: int = 0
    ) -> list[DeltaToolCall]:
        """Convert complete ToolCall list to DeltaToolCall list.
csy0225's avatar
csy0225 committed
813

luopl's avatar
luopl committed
814
        Returns complete tool calls without splitting into fragments.
csy0225's avatar
csy0225 committed
815

luopl's avatar
luopl committed
816
817
818
        Args:
            tool_calls: List of tool calls to convert
            starting_index: Starting index for tool calls (default 0)
csy0225's avatar
csy0225 committed
819

luopl's avatar
luopl committed
820
821
        Returns:
            List of DeltaToolCall with complete arguments
csy0225's avatar
csy0225 committed
822
        """
luopl's avatar
luopl committed
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
        delta_tool_calls = []
        for i, tool_call in enumerate[ToolCall](tool_calls):
            index = starting_index + i
            tool_id = self._generate_tool_call_id()

            # Create complete DeltaToolCall with full arguments
            delta_tool_calls.append(
                DeltaToolCall(
                    index=index,
                    id=tool_id,
                    function=DeltaFunctionCall(
                        name=tool_call.function.name,
                        arguments=tool_call.function.arguments,
                    ),
                    type="function",
csy0225's avatar
csy0225 committed
838
839
840
                )
            )

luopl's avatar
luopl committed
841
        return delta_tool_calls
csy0225's avatar
csy0225 committed
842

luopl's avatar
luopl committed
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Extract tool calls from streaming text using complete parsing.

        Strategy:
        1. Accumulate text in buffer and track processed position
        2. In each iteration, try to extract content or complete tool calls
        3. Parse complete tool calls using non-streaming method
        4. Convert parsed results to delta sequence
        5. Handle EOS token to flush incomplete tool calls as content
        """
        # Initialize state for new request
        if not previous_text:
            self._reset_streaming_state()
            self.streaming_request = request

        # Check for EOS token
        has_eos = (
            self.eos_token_id is not None
            and delta_token_ids
            and self.eos_token_id in delta_token_ids
        )
csy0225's avatar
csy0225 committed
873

luopl's avatar
luopl committed
874
875
876
877
878
879
880
        # If no delta text, check if we need to return empty delta for finish_reason
        if not delta_text and not has_eos:
            # Check if this is an EOS token after all tool calls are complete
            if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids:
                # Count complete tool calls
                complete_calls = len(
                    self.tool_call_complete_regex.findall(current_text)
csy0225's avatar
csy0225 committed
881
882
                )

luopl's avatar
luopl committed
883
884
885
886
887
888
889
890
891
892
893
894
895
                # If we have completed tool calls and populated prev_tool_call_arr
                if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
                    # Check if all tool calls are closed
                    open_calls = current_text.count(
                        self.tool_call_start_token
                    ) - current_text.count(self.tool_call_end_token)
                    if open_calls == 0:
                        # Return empty delta for finish_reason processing
                        return DeltaMessage(content="")
            return None

        # Process all available content
        accumulated_deltas: list[DeltaMessage] = []
csy0225's avatar
csy0225 committed
896

luopl's avatar
luopl committed
897
898
899
        while self._has_unprocessed_content(current_text):
            # Try to process next chunk (content or tool call)
            delta = self._process_next_chunk(current_text)
csy0225's avatar
csy0225 committed
900

luopl's avatar
luopl committed
901
902
903
            if delta is None:
                # Cannot proceed further, need more tokens
                break
csy0225's avatar
csy0225 committed
904

luopl's avatar
luopl committed
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
            # Accumulate deltas
            if isinstance(delta, list):
                accumulated_deltas.extend(delta)
            else:
                accumulated_deltas.append(delta)

        # Handle EOS: flush any remaining incomplete tool calls as content
        if has_eos:
            remaining_delta = self._flush_remaining_content(current_text)
            if remaining_delta:
                accumulated_deltas.append(remaining_delta)
            # If no remaining content but we have tool calls, return empty delta
            elif len(self.prev_tool_call_arr) > 0:
                # Check if all tool calls are closed
                open_calls = current_text.count(
                    self.tool_call_start_token
                ) - current_text.count(self.tool_call_end_token)
                if open_calls == 0:
                    accumulated_deltas.append(DeltaMessage(content=""))

        # Return results
        return self._format_delta_result(accumulated_deltas)

    def _has_unprocessed_content(self, current_text: str) -> bool:
        """Check if there's unprocessed content in the buffer."""
        return self._processed_length < len(current_text)

    def _process_next_chunk(
        self, current_text: str
    ) -> DeltaMessage | list[DeltaMessage] | None:
        """Process next chunk: either regular content or a complete tool call.
csy0225's avatar
csy0225 committed
936

luopl's avatar
luopl committed
937
938
        Args:
            current_text: Current accumulated text
csy0225's avatar
csy0225 committed
939

luopl's avatar
luopl committed
940
941
942
943
944
945
946
947
        Returns:
            - DeltaMessage or list of DeltaMessage if processed successfully
            - None if cannot proceed (need more tokens)
        """
        # Find next tool call start
        tool_start_idx = self._find_tool_call_start(
            current_text, self._processed_length
        )
csy0225's avatar
csy0225 committed
948

luopl's avatar
luopl committed
949
950
951
952
953
        # Case 1: No tool call found - return remaining content
        if tool_start_idx == -1:
            return self._process_content(
                current_text, self._processed_length, len(current_text)
            )
csy0225's avatar
csy0225 committed
954

luopl's avatar
luopl committed
955
956
957
958
959
        # Case 2: Content before tool call
        if tool_start_idx > self._processed_length:
            return self._process_content(
                current_text, self._processed_length, tool_start_idx
            )
csy0225's avatar
csy0225 committed
960

luopl's avatar
luopl committed
961
962
963
964
965
        # Case 3: Tool call at current position
        # Find end of the first complete tool call
        tool_end_idx = self._find_first_complete_tool_call_end(
            current_text, tool_start_idx
        )
csy0225's avatar
csy0225 committed
966

luopl's avatar
luopl committed
967
968
969
        if tool_end_idx == -1:
            # Tool call incomplete, wait for more tokens
            return None
csy0225's avatar
csy0225 committed
970

luopl's avatar
luopl committed
971
972
973
974
975
976
977
978
979
        # Process complete tool call
        return self._process_complete_tool_calls(
            current_text, tool_start_idx, tool_end_idx
        )

    def _process_content(
        self, current_text: str, start_pos: int, end_pos: int
    ) -> DeltaMessage | None:
        """Process regular content (non-tool-call text).
csy0225's avatar
csy0225 committed
980
981

        Args:
luopl's avatar
luopl committed
982
983
984
            current_text: Current accumulated text
            start_pos: Start position in buffer
            end_pos: End position in buffer
csy0225's avatar
csy0225 committed
985
986

        Returns:
luopl's avatar
luopl committed
987
            DeltaMessage with content if non-empty
csy0225's avatar
csy0225 committed
988
        """
luopl's avatar
luopl committed
989
990
        if start_pos >= end_pos:
            return None
csy0225's avatar
csy0225 committed
991

luopl's avatar
luopl committed
992
993
994
995
996
997
        content = current_text[start_pos:end_pos]

        # Check if we're between tool calls - skip whitespace
        if start_pos > 0:
            # Check if text before start_pos ends with </tool_call>
            text_before = current_text[:start_pos]
csy0225's avatar
csy0225 committed
998
            if (
luopl's avatar
luopl committed
999
1000
                text_before.rstrip().endswith(self.tool_call_end_token)
                and content.strip() == ""
csy0225's avatar
csy0225 committed
1001
            ):
luopl's avatar
luopl committed
1002
1003
1004
                # We just ended a tool call, skip whitespace between tool calls
                self._processed_length = end_pos
                return None
csy0225's avatar
csy0225 committed
1005

luopl's avatar
luopl committed
1006
1007
1008
1009
        # Return content if non-empty
        if content:
            self._processed_length = end_pos
            return DeltaMessage(content=content)
csy0225's avatar
csy0225 committed
1010

luopl's avatar
luopl committed
1011
1012
1013
1014
1015
1016
        # Mark as processed even if empty
        self._processed_length = end_pos
        return None

    def _flush_remaining_content(self, current_text: str) -> DeltaMessage | None:
        """Flush any remaining unprocessed content as regular content.
csy0225's avatar
csy0225 committed
1017
1018

        Args:
luopl's avatar
luopl committed
1019
            current_text: Current accumulated text
csy0225's avatar
csy0225 committed
1020

luopl's avatar
luopl committed
1021
        Used when EOS token is encountered to handle incomplete tool calls.
csy0225's avatar
csy0225 committed
1022
        """
luopl's avatar
luopl committed
1023
        if not self._has_unprocessed_content(current_text):
csy0225's avatar
csy0225 committed
1024
1025
            return None

luopl's avatar
luopl committed
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
        remaining = current_text[self._processed_length :]
        if remaining:
            self._processed_length = len(current_text)
            return DeltaMessage(content=remaining)

        self._processed_length = len(current_text)
        return None

    def _format_delta_result(self, deltas: list[DeltaMessage]) -> DeltaMessage | None:
        """Format delta result for return.

        Merges all deltas into a single DeltaMessage.
csy0225's avatar
csy0225 committed
1038
1039

        Args:
luopl's avatar
luopl committed
1040
            deltas: List of delta messages
csy0225's avatar
csy0225 committed
1041
1042

        Returns:
luopl's avatar
luopl committed
1043
1044
            - None if empty
            - Single merged DeltaMessage with all content and tool_calls
csy0225's avatar
csy0225 committed
1045
        """
luopl's avatar
luopl committed
1046
1047
        if not deltas:
            return None
csy0225's avatar
csy0225 committed
1048

luopl's avatar
luopl committed
1049
1050
        if len(deltas) == 1:
            return deltas[0]
csy0225's avatar
csy0225 committed
1051

luopl's avatar
luopl committed
1052
1053
1054
        # Merge multiple deltas into one
        merged_content_parts = []
        merged_tool_calls = []
csy0225's avatar
csy0225 committed
1055

luopl's avatar
luopl committed
1056
1057
1058
1059
1060
        for delta in deltas:
            if delta.content:
                merged_content_parts.append(delta.content)
            if delta.tool_calls:
                merged_tool_calls.extend(delta.tool_calls)
csy0225's avatar
csy0225 committed
1061

luopl's avatar
luopl committed
1062
1063
        # Create merged DeltaMessage
        merged_content = "".join(merged_content_parts) if merged_content_parts else None
csy0225's avatar
csy0225 committed
1064

luopl's avatar
luopl committed
1065
1066
1067
1068
        # Build kwargs - only include tool_calls if non-empty
        kwargs: dict[str, Any] = {"content": merged_content}
        if merged_tool_calls:
            kwargs["tool_calls"] = merged_tool_calls
csy0225's avatar
csy0225 committed
1069

luopl's avatar
luopl committed
1070
        return DeltaMessage(**kwargs)
csy0225's avatar
csy0225 committed
1071

luopl's avatar
luopl committed
1072
1073
1074
1075
    def _process_complete_tool_calls(
        self, current_text: str, start_pos: int, end_pos: int
    ) -> list[DeltaMessage] | None:
        """Parse a closed tool-call span and turn it into deltas.

        Args:
            current_text: Current accumulated text.
            start_pos: Start index (expected to be at <tool_call>).
            end_pos: End index (just past </tool_call>).

        Returns:
            The deltas built for the parsed tool calls; a single content delta
            holding the raw span when parsing yields no tool calls or raises;
            None when there is nothing to emit. In all cases the span is
            marked as processed.
        """
        try:
            segment = current_text[start_pos:end_pos]

            # Reuse the non-streaming parser on just this span.
            parsed = self.extract_tool_calls_basic(segment, self.streaming_request)

            if not (parsed.tools_called and parsed.tool_calls):
                # Nothing usable was parsed: hand the raw text to the client.
                self._processed_length = end_pos
                return [DeltaMessage(content=segment)]

            # The span normally holds exactly one tool call because the caller
            # cuts it at the first closing tag, but several are handled too.
            deltas = self._build_tool_call_deltas(parsed.tool_calls, segment)
            self._update_state_after_tool_calls(parsed.tool_calls, end_pos)
            return deltas or None

        except Exception as e:
            # Exception during parsing - treat as content
            logger.debug("Failed to parse tool calls: %s, treating as content", e)
            self._processed_length = end_pos
            failed_text = current_text[start_pos:end_pos]
            return [DeltaMessage(content=failed_text)] if failed_text else None

    def _build_tool_call_deltas(
        self, tool_calls: list[ToolCall], parsed_text: str
    ) -> list[DeltaMessage]:
        """Package parsed tool calls and the text following them into a delta.

        Args:
            tool_calls: Parsed tool calls.
            parsed_text: The text the tool calls were parsed from.

        Returns:
            A one-element list holding a ``DeltaMessage`` with the tool calls
            (indexed from ``_tool_call_index``) and the concatenated text
            found after each call, or an empty list when there is neither.
        """
        # Text found after each tool call, joined into one string.
        trailing_text = "".join(
            self._extract_content_between_tool_calls_list(parsed_text)
        )

        # Numbering continues from the calls already emitted on this stream.
        delta_calls = self._convert_tool_calls_to_deltas(
            tool_calls, self._tool_call_index
        )

        # Only non-empty fields are passed to DeltaMessage.
        fields: dict[str, Any] = {}
        if trailing_text:
            fields["content"] = trailing_text
        if delta_calls:
            fields["tool_calls"] = delta_calls

        return [DeltaMessage(**fields)] if fields else []

    def _update_state_after_tool_calls(
        self, tool_calls: list[ToolCall], end_pos: int
    ) -> None:
        """Record that ``tool_calls`` have been emitted.

        Args:
            tool_calls: Tool calls that were just processed.
            end_pos: Index in the buffer up to which text is now consumed.
        """
        emitted = len(tool_calls)

        # Advance the read position past the processed span.
        self._processed_length = end_pos

        # Later tool calls continue numbering after the ones just emitted.
        self._tool_call_index = self._tool_call_index + emitted

        # Update prev_tool_call_arr for finish_reason
        self._update_prev_tool_call_state(tool_calls)