minimax_tool_parser.py 27.8 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
from collections.abc import Sequence
6
from typing import Any
7
8
9

import regex as re

10
from vllm.entrypoints.chat_utils import make_tool_call_id
11
12
13
14
15
16
17
18
19
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
20
from vllm.logger import init_logger
21
from vllm.tokenizers import TokenizerLike
22
23
24
25
from vllm.tool_parsers.abstract_tool_parser import (
    ToolParser,
)
from vllm.tool_parsers.utils import extract_intermediate_diff
26
27
28
29
30

logger = init_logger(__name__)


class MinimaxToolParser(ToolParser):
31
    def __init__(self, tokenizer: TokenizerLike):
        """
        Set up tool call markers, regex patterns and streaming state.

        Args:
            tokenizer: Tokenizer forwarded to the ToolParser constructor

        Raises:
            ValueError: If self.model_tokenizer is falsy after the base
                class has been initialized
        """
        super().__init__(tokenizer)

        # Literal markers that wrap the tool call section of the output.
        self.tool_call_start_token = "<tool_calls>"
        self.tool_call_end_token = "</tool_calls>"

        # First group captures a closed section, second group a section
        # whose end marker has not been produced.
        self.tool_call_regex = re.compile(
            r"<tool_calls>(.*?)</tool_calls>|<tool_calls>(.*)", re.DOTALL
        )
        self.thinking_tag_pattern = r"<think>(.*?)</think>"
        self.tool_name_pattern = re.compile(r'"name":\s*"([^"]+)"')
        self.tool_args_pattern = re.compile(r'"arguments":\s*')

        # Progress bookkeeping for tool calls emitted while streaming.
        self.streaming_state: dict[str, Any] = {
            "current_tool_index": -1,  # Index of current tool being processed
            "tool_ids": [],  # List of tool call IDs
            "sent_tools": [],  # List of tools that have been sent
        }

        # Text held back while it might still turn into a tool call tag.
        self.pending_buffer = ""
        self.in_thinking_tag = False

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        # Look up the marker token IDs; either may be absent from the vocab.
        start_id = self.vocab.get(self.tool_call_start_token)
        end_id = self.vocab.get(self.tool_call_end_token)
        self.tool_call_start_token_id = start_id
        self.tool_call_end_token_id = end_id

        if start_id is None or end_id is None:
            logger.warning(
                "Minimax Tool parser could not locate tool call start/end "
                "tokens in the tokenizer. Falling back to string matching."
            )
70
71
72

    def preprocess_model_output(self, model_output: str) -> str:
        """
73
        Preprocess model output by removing tool calls from thinking tags.
74

75
76
        Args:
            model_output: Raw model output string
77

78
79
        Returns:
            Preprocessed model output with tool calls removed from thinking tags
80
81
82
83
        """

        def remove_tool_calls_from_think(match):
            think_content = match.group(1)
84
85
86
            cleaned_content = re.sub(
                r"<tool_calls>.*?</tool_calls>", "", think_content, flags=re.DOTALL
            )
87
88
            return f"<think>{cleaned_content}</think>"

89
90
91
92
93
94
        return re.sub(
            self.thinking_tag_pattern,
            remove_tool_calls_from_think,
            model_output,
            flags=re.DOTALL,
        )
95

96
97
98
    def _clean_duplicate_braces(self, args_text: str) -> str:
        """
        Clean duplicate closing braces from arguments text.
99

100
101
        Args:
            args_text: Raw arguments text
102

103
104
105
106
107
108
109
110
111
112
113
114
115
        Returns:
            Cleaned arguments text with proper JSON formatting
        """
        args_text = args_text.strip()
        if not args_text:
            return args_text

        try:
            json.loads(args_text)
            return args_text
        except json.JSONDecodeError:
            pass

116
        while args_text.endswith("}}"):
117
118
119
120
121
122
123
124
125
126
127
128
            candidate = args_text[:-1]
            try:
                json.loads(candidate)
                return candidate
            except json.JSONDecodeError:
                args_text = candidate

        return args_text

    def _clean_delta_braces(self, delta_text: str) -> str:
        """
        Clean delta text by removing excessive closing braces.
129

130
131
        Args:
            delta_text: Delta text to clean
132

133
134
135
136
137
138
139
140
        Returns:
            Cleaned delta text
        """
        if not delta_text:
            return delta_text

        delta_stripped = delta_text.strip()

141
142
        if delta_stripped and all(c in "}\n\r\t " for c in delta_stripped):
            brace_count = delta_stripped.count("}")
143
            if brace_count > 1:
144
                return "}\n" if delta_text.endswith("\n") else "}"
145
146

        return delta_text
147
148
149
150
151
152

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """
        Parse tool calls out of a complete (non-streaming) model output.

        Args:
            model_output: Complete model output
            request: Chat completion request (not read by this method)

        Returns:
            ExtractedToolCallInformation with the parsed tool calls and the
            content preceding them; on any unexpected error, or when no tool
            call start token is present, the raw output is returned as
            content with no tool calls
        """
        scrubbed = self.preprocess_model_output(model_output)
        start_tag = self.tool_call_start_token

        if start_tag not in scrubbed:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        try:
            # Each tool call is expected on its own line as a JSON object;
            # lines that do not parse are skipped.
            parsed_calls = []
            for closed_body, open_body in self.tool_call_regex.findall(scrubbed):
                section = (closed_body or open_body).strip()
                if not section:
                    continue
                for raw_line in section.split("\n"):
                    candidate = raw_line.strip()
                    if not (candidate.startswith("{") and candidate.endswith("}")):
                        continue
                    try:
                        parsed_calls.append(json.loads(candidate))
                    except json.JSONDecodeError:
                        continue

            tool_calls = [
                ToolCall(
                    type="function",
                    function=FunctionCall(
                        name=call["name"],
                        arguments=json.dumps(call["arguments"], ensure_ascii=False),
                    ),
                )
                for call in parsed_calls
                if "name" in call and "arguments" in call
            ]

            tag_at = scrubbed.find(start_tag)
            if tag_at == -1:
                content = model_output
            else:
                # Locate the last non-empty line before the tool calls in the
                # raw output and keep the raw text up to the end of that line.
                content = ""
                for raw_line in reversed(scrubbed[:tag_at].strip().split("\n")):
                    anchor = raw_line.strip()
                    if not anchor:
                        continue
                    found_at = model_output.find(anchor)
                    if found_at != -1:
                        content = model_output[: found_at + len(anchor)]
                        break

            visible = content.strip()
            return ExtractedToolCallInformation(
                tools_called=bool(tool_calls),
                tool_calls=tool_calls,
                content=visible if visible else None,
            )

        except Exception:
            logger.exception(
                "An unexpected error occurred during tool call extraction."
            )
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )
235

236
237
238
    def _update_thinking_state(self, text: str) -> None:
        """
        Update the thinking tag state based on text content.
239

240
241
242
243
244
245
        Args:
            text: Text to analyze for thinking tags
        """
        open_count = text.count("<think>")
        close_count = text.count("</think>")
        self.in_thinking_tag = open_count > close_count or (
246
247
            open_count == close_count and text.endswith("</think>")
        )
248
249
250
251

    def _is_potential_tag_start(self, text: str) -> bool:
        """
        Check if text might be the start of a tool call tag.
252

253
254
        Args:
            text: Text to check
255

256
257
258
259
260
        Returns:
            True if text could be the start of a tool call tag
        """
        for tag in [self.tool_call_start_token, self.tool_call_end_token]:
            if any(
261
262
263
                tag.startswith(text[-i:])
                for i in range(1, min(len(text) + 1, len(tag)))
            ):
264
265
266
267
268
269
                return True
        return False

    def _should_buffer_content(self, delta_text: str) -> bool:
        """
        Determine if content should be buffered for later processing.
270

271
272
        Args:
            delta_text: Delta text to check
273

274
275
276
277
278
        Returns:
            True if content should be buffered
        """
        if self.in_thinking_tag:
            return False
279
280
281
282
283
284
        return bool(
            self.pending_buffer
            or self.tool_call_start_token in delta_text
            or self.tool_call_end_token in delta_text
            or delta_text.startswith("<")
        )
285
286
287
288

    def _split_content_for_buffering(self, delta_text: str) -> tuple[str, str]:
        """
        Split delta text into safe content and potential tag content.
289

290
291
        Args:
            delta_text: Delta text to split
292

293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
        Returns:
            Tuple of (safe_content, potential_tag_content)
        """
        if self.in_thinking_tag:
            return delta_text, ""

        for tag in [self.tool_call_start_token, self.tool_call_end_token]:
            for i in range(1, len(tag)):
                tag_prefix = tag[:i]
                pos = delta_text.rfind(tag_prefix)
                if pos != -1 and tag.startswith(delta_text[pos:]):
                    return delta_text[:pos], delta_text[pos:]
        return delta_text, ""

    def _process_buffer(self, new_content: str) -> str:
        """
        Process buffered content and return output content.
310

311
312
        Args:
            new_content: New content to add to buffer
313

314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
        Returns:
            Processed output content
        """
        self.pending_buffer += new_content
        output_content = ""

        if self.in_thinking_tag:
            output_content = self.pending_buffer
            self.pending_buffer = ""
            return output_content

        while self.pending_buffer:
            start_pos = self.pending_buffer.find(self.tool_call_start_token)
            end_pos = self.pending_buffer.find(self.tool_call_end_token)

            if start_pos != -1 and (end_pos == -1 or start_pos < end_pos):
                tag_pos, tag_len = start_pos, len(self.tool_call_start_token)
            elif end_pos != -1:
                tag_pos, tag_len = end_pos, len(self.tool_call_end_token)
            else:
                if self._is_potential_tag_start(self.pending_buffer):
                    break
                output_content += self.pending_buffer
                self.pending_buffer = ""
                break

            output_content += self.pending_buffer[:tag_pos]
341
            self.pending_buffer = self.pending_buffer[tag_pos + tag_len :]
342
343
344
345
346
347
348
349
350
351
352
353
354

        return output_content

    def _reset_streaming_state(self) -> None:
        """Reset the streaming state to initial values."""
        self.streaming_state = {
            "current_tool_index": -1,
            "tool_ids": [],
            "sent_tools": [],
        }

    def _advance_to_next_tool(self) -> None:
        """Advance to the next tool in the streaming sequence."""
355
356
357
        self.streaming_state["current_tool_index"] = (
            int(self.streaming_state["current_tool_index"]) + 1
        )
358
359
360
361

    def _set_current_tool_index(self, index: int) -> None:
        """
        Set the current tool index.
362

363
364
365
366
367
368
369
370
        Args:
            index: Tool index to set
        """
        self.streaming_state["current_tool_index"] = index

    def _get_current_tool_index(self) -> int:
        """
        Get the current tool index.
371

372
373
374
375
376
377
378
379
        Returns:
            Current tool index
        """
        return int(self.streaming_state["current_tool_index"])

    def _get_next_unsent_tool_index(self, tool_count: int) -> int:
        """
        Get the index of the next unsent tool.
380

381
382
        Args:
            tool_count: Total number of tools
383

384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
        Returns:
            Index of next unsent tool, or -1 if all tools sent
        """
        sent_tools = list(self.streaming_state["sent_tools"])
        for i in range(tool_count):
            if i < len(sent_tools):
                if not sent_tools[i]["sent_name"]:
                    return i
            else:
                return i
        return -1

    def _ensure_state_arrays(self, tool_count: int) -> None:
        """
        Ensure state arrays have sufficient capacity for tool_count tools.
399

400
401
402
403
404
405
406
        Args:
            tool_count: Number of tools to prepare for
        """
        sent_tools = list(self.streaming_state["sent_tools"])
        tool_ids = list(self.streaming_state["tool_ids"])

        while len(sent_tools) < tool_count:
407
408
409
410
411
412
413
            sent_tools.append(
                {
                    "sent_name": False,
                    "sent_arguments": "",
                    "id": make_tool_call_id(),
                }
            )
414
415
416
417
418
419
420
421
422
423

        while len(tool_ids) < tool_count:
            tool_ids.append(None)

        self.streaming_state["sent_tools"] = sent_tools
        self.streaming_state["tool_ids"] = tool_ids

    def _detect_tools_in_text(self, text: str) -> int:
        """
        Detect the number of tools in text by counting name patterns.
424

425
426
        Args:
            text: Text to analyze
427

428
429
430
431
432
433
434
435
436
        Returns:
            Number of tools detected
        """
        matches = self.tool_name_pattern.findall(text)
        return len(matches)

    def _find_tool_boundaries(self, text: str) -> list[tuple[int, int]]:
        """
        Find the boundaries of tool calls in text.
437

438
439
        Args:
            text: Text to analyze
440

441
442
443
444
445
446
        Returns:
            List of (start, end) positions for tool calls
        """
        boundaries = []
        i = 0
        while i < len(text):
447
            if text[i] == "{":
448
449
450
451
452
453
                start = i
                depth = 0
                has_name = False
                has_arguments = False

                while i < len(text):
454
                    if text[i] == "{":
455
                        depth += 1
456
                    elif text[i] == "}":
457
458
459
460
461
462
463
464
                        depth -= 1
                        if depth == 0:
                            end = i + 1
                            segment = text[start:end]
                            if '"name"' in segment and '"arguments"' in segment:
                                boundaries.append((start, end))
                            break

465
                    if not has_name and '"name"' in text[start : i + 1]:
466
                        has_name = True
467
                    if not has_arguments and '"arguments"' in text[start : i + 1]:
468
469
470
471
472
473
474
475
476
477
                        has_arguments = True

                    i += 1

                if depth > 0 and has_name:
                    boundaries.append((start, i))
            else:
                i += 1
        return boundaries

478
    def _extract_tool_args(self, tool_content: str, args_match: re.Match[str]) -> str:
479
480
        """
        Extract tool arguments from tool content.
481

482
483
484
        Args:
            tool_content: Tool call content
            args_match: Regex match for arguments pattern
485

486
487
488
489
490
491
        Returns:
            Extracted arguments as string
        """
        args_start_pos = args_match.end()
        remaining_content = tool_content[args_start_pos:]

492
        if remaining_content.strip().startswith("{"):
493
494
            depth = 0
            for i, char in enumerate(remaining_content):
495
                if char == "{":
496
                    depth += 1
497
                elif char == "}":
498
499
                    depth -= 1
                    if depth == 0:
500
                        return remaining_content[: i + 1]
501
        else:
502
            args_end = remaining_content.find("}")
503
504
505
            if args_end > 0:
                return remaining_content[:args_end].strip()

506
        return remaining_content.rstrip("}").strip()
507
508

    def _get_current_tool_content(
509
        self, text: str, tool_index: int
510
    ) -> tuple[str | None, str | None]:
511
512
        """
        Get the content of a specific tool by index.
513

514
515
516
        Args:
            text: Text containing tool calls
            tool_index: Index of tool to extract
517

518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
        Returns:
            Tuple of (tool_name, tool_arguments) or (None, None) if not found
        """
        boundaries = self._find_tool_boundaries(text)

        if tool_index >= len(boundaries):
            return None, None

        start, end = boundaries[tool_index]
        tool_content = text[start:end]

        name_match = self.tool_name_pattern.search(tool_content)
        name = name_match.group(1) if name_match else None

        args_match = self.tool_args_pattern.search(tool_content)
        if args_match:
            try:
                args_text = self._extract_tool_args(tool_content, args_match)
                return name, args_text
            except Exception:
538
539
                remaining_content = tool_content[args_match.end() :]
                args_text = remaining_content.rstrip("}").strip()
540
541
542
543
544
                return name, args_text

        return name, None

    def _handle_tool_name_streaming(
545
        self, tool_content: str, tool_count: int
546
    ) -> DeltaMessage | None:
547
548
        """
        Handle streaming of tool names.
549

550
551
552
        Args:
            tool_content: Content containing tool calls
            tool_count: Total number of tools
553

554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
        Returns:
            DeltaMessage with tool name or None if no tool to stream
        """
        next_idx = self._get_next_unsent_tool_index(tool_count)

        if next_idx == -1:
            return None

        boundaries = self._find_tool_boundaries(tool_content)
        if next_idx >= len(boundaries):
            return None

        tool_name, _ = self._get_current_tool_content(tool_content, next_idx)
        if not tool_name:
            return None

        self._set_current_tool_index(next_idx)
        sent_tools = list(self.streaming_state["sent_tools"])
        tool_ids = list(self.streaming_state["tool_ids"])

        tool_id = sent_tools[next_idx]["id"]
        tool_ids[next_idx] = tool_id
        sent_tools[next_idx]["sent_name"] = True

        self.streaming_state["sent_tools"] = sent_tools
        self.streaming_state["tool_ids"] = tool_ids

581
582
583
584
585
586
587
588
589
590
591
592
        return DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=next_idx,
                    type="function",
                    id=tool_id,
                    function=DeltaFunctionCall(name=tool_name).model_dump(
                        exclude_none=True
                    ),
                )
            ]
        )
593
594

    def _handle_tool_args_streaming(
595
        self, tool_content: str, tool_count: int
596
    ) -> DeltaMessage | None:
597
598
        """
        Handle streaming of tool arguments.
599

600
601
602
        Args:
            tool_content: Content containing tool calls
            tool_count: Total number of tools
603

604
605
606
607
608
609
610
611
        Returns:
            DeltaMessage with tool arguments or None if no arguments to stream
        """
        current_idx = self._get_current_tool_index()

        if current_idx < 0 or current_idx >= tool_count:
            return None

612
        tool_name, tool_args = self._get_current_tool_content(tool_content, current_idx)
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
        if not tool_name or tool_args is None:
            return None

        sent_tools = list(self.streaming_state["sent_tools"])

        if not sent_tools[current_idx]["sent_name"]:
            return None

        clean_args = self._clean_duplicate_braces(tool_args)
        sent_args = sent_tools[current_idx]["sent_arguments"]

        if clean_args != sent_args:
            if sent_args and clean_args.startswith(sent_args):
                args_delta = extract_intermediate_diff(clean_args, sent_args)
                if args_delta:
                    args_delta = self._clean_delta_braces(args_delta)
                    sent_tools[current_idx]["sent_arguments"] = clean_args
                    self.streaming_state["sent_tools"] = sent_tools

632
                    if clean_args.endswith("}"):
633
634
                        self._advance_to_next_tool()

635
636
637
638
639
640
641
642
643
644
                    return DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=current_idx,
                                function=DeltaFunctionCall(
                                    arguments=args_delta
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
645
646
647
648
649
            elif not sent_args and clean_args:
                clean_args_delta = self._clean_delta_braces(clean_args)
                sent_tools[current_idx]["sent_arguments"] = clean_args
                self.streaming_state["sent_tools"] = sent_tools

650
                if clean_args.endswith("}"):
651
652
                    self._advance_to_next_tool()

653
654
655
656
657
658
659
660
661
662
                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=current_idx,
                            function=DeltaFunctionCall(
                                arguments=clean_args_delta
                            ).model_dump(exclude_none=True),
                        )
                    ]
                )
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679

        return None

    def _is_end_tool_calls(self, current_text: str) -> bool:
        if self.tool_call_end_token not in current_text:
            return False

        end_token_positions = []
        search_start = 0
        while True:
            pos = current_text.find(self.tool_call_end_token, search_start)
            if pos == -1:
                break
            end_token_positions.append(pos)
            search_start = pos + 1

        think_regions = []
680
681
682
        for match in re.finditer(
            self.thinking_tag_pattern, current_text, flags=re.DOTALL
        ):
683
684
685
            think_regions.append((match.start(), match.end()))

        for pos in end_token_positions:
686
687
688
            in_think = any(
                pos >= t_start and pos < t_end for t_start, t_end in think_regions
            )
689
690
691
692
693
            if not in_think:
                return True

        return False

694
695
696
697
698
699
700
701
702
    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
703
    ) -> DeltaMessage | None:
704
705
706
707
708
709
710
        self._update_thinking_state(current_text)

        if self.in_thinking_tag:
            return DeltaMessage(content=delta_text)

        if self._should_buffer_content(delta_text):
            buffered_output = self._process_buffer(delta_text)
711
            return DeltaMessage(content=buffered_output) if buffered_output else None
712
713
714
715

        if self._is_end_tool_calls(current_text):
            return DeltaMessage(content=delta_text)

716
        safe_content, potential_tag = self._split_content_for_buffering(delta_text)
717
718
719
        if potential_tag:
            self.pending_buffer += potential_tag
            return DeltaMessage(content=safe_content) if safe_content else None
720
721
722
723

        processed_current_text = self.preprocess_model_output(current_text)

        if self.tool_call_start_token not in processed_current_text:
724
725
726
727
            if (
                self.tool_call_end_token in delta_text
                and self.tool_call_start_token in current_text
            ):
728
                return None
729
            if delta_text.strip() == "" and self.tool_call_start_token in current_text:
730
                return None
731
732
733
734
            if (
                self._get_current_tool_index() != -1
                and self.tool_call_end_token in current_text
            ):
735
                self._reset_streaming_state()
736
737
            return DeltaMessage(content=delta_text)

738
739
740
741
742
        if (
            self.tool_call_start_token_id is not None
            and self.tool_call_start_token_id in delta_token_ids
            and len(delta_token_ids) == 1
        ):
743
744
            return None

745
        original_tool_start = self._find_tool_start_outside_thinking(current_text)
746
747
        if original_tool_start is None:
            return None
748

749
        content_before_tools = self._extract_content_before_tools(
750
751
            current_text, delta_text, original_tool_start
        )
752
753
        if content_before_tools:
            return DeltaMessage(content=content_before_tools)
754
755

        try:
756
            tool_content = self._extract_tool_content(current_text, original_tool_start)
757
758
759
            current_tools_count = self._detect_tools_in_text(tool_content)

            if current_tools_count == 0:
760
761
                return None

762
763
            if self._get_current_tool_index() == -1:
                self._reset_streaming_state()
764

765
            self._ensure_state_arrays(current_tools_count)
766

767
768
769
            return self._handle_tool_name_streaming(
                tool_content, current_tools_count
            ) or self._handle_tool_args_streaming(tool_content, current_tools_count)
770
771

        except Exception:
772
773
774
            logger.exception(
                "An unexpected error occurred ", "during streaming tool call handling."
            )
775
            return None
776

777
    def _find_tool_start_outside_thinking(self, current_text: str) -> int | None:
778
779
        """
        Find the start position of tool calls outside of thinking tags.
780

781
782
        Args:
            current_text: Current text to search
783

784
785
786
787
788
789
790
791
792
        Returns:
            Position of tool call start or None if not found
        """
        search_start = 0
        while True:
            pos = current_text.find(self.tool_call_start_token, search_start)
            if pos == -1:
                return None

793
794
795
796
797
798
799
800
801
            think_regions = [
                (m.start(), m.end())
                for m in re.finditer(
                    r"<think>(.*?)</think>", current_text, flags=re.DOTALL
                )
            ]
            in_think = any(
                pos >= t_start and pos < t_end for t_start, t_end in think_regions
            )
802
803
804
805
806
807

            if not in_think:
                return pos

            search_start = pos + 1

808
809
    def _extract_content_before_tools(
        self, current_text: str, delta_text: str, tool_start: int
810
    ) -> str | None:
811
812
        """
        Extract content that appears before tool calls.
813

814
815
816
817
        Args:
            current_text: Current text
            delta_text: Delta text
            tool_start: Start position of tools
818

819
820
821
822
823
824
825
826
        Returns:
            Content before tools or None
        """
        if tool_start > 0:
            delta_start_pos = len(current_text) - len(delta_text)
            if delta_start_pos < tool_start:
                content_part = delta_text
                if delta_start_pos + len(delta_text) > tool_start:
827
                    content_part = delta_text[: tool_start - delta_start_pos]
828
829
830
831
832
833
                return content_part if content_part else None
        return None

    def _extract_tool_content(self, current_text: str, tool_start: int) -> str:
        """
        Extract tool content from current text starting at tool_start.
834

835
836
837
        Args:
            current_text: Current text
            tool_start: Start position of tool calls
838

839
840
841
842
843
844
845
846
847
848
849
        Returns:
            Extracted tool content
        """
        tool_content_start = tool_start + len(self.tool_call_start_token)
        tool_content = current_text[tool_content_start:]

        end_pos = tool_content.find(self.tool_call_end_token)
        if end_pos != -1:
            tool_content = tool_content[:end_pos]

        return tool_content