minimax_tool_parser.py 27.9 KB
Newer Older
1
2
3
4
5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
from collections.abc import Sequence
6
from typing import Any
7
8
9

import regex as re

10
from vllm.entrypoints.chat_utils import make_tool_call_id
11
from vllm.entrypoints.openai.chat_completion.protocol import (
12
    ChatCompletionRequest,
13
14
)
from vllm.entrypoints.openai.engine.protocol import (
15
16
17
18
19
20
21
    DeltaFunctionCall,
    DeltaMessage,
    DeltaToolCall,
    ExtractedToolCallInformation,
    FunctionCall,
    ToolCall,
)
22
from vllm.logger import init_logger
23
from vllm.tokenizers import TokenizerLike
24
25
26
27
from vllm.tool_parsers.abstract_tool_parser import (
    ToolParser,
)
from vllm.tool_parsers.utils import extract_intermediate_diff
28
29
30
31
32

logger = init_logger(__name__)


class MinimaxToolParser(ToolParser):
33
    def __init__(self, tokenizer: TokenizerLike):
        """
        Set up tool call markers, regex patterns and streaming bookkeeping.

        Args:
            tokenizer: Tokenizer forwarded to the base ToolParser

        Raises:
            ValueError: If no model tokenizer is available after the base
                class has been initialized
        """
        super().__init__(tokenizer)

        # Fail fast: the token ID lookups below need a tokenizer.
        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        # Literal markers wrapping the tool call section of the output
        self.tool_call_start_token = "<tool_calls>"
        self.tool_call_end_token = "</tool_calls>"

        # Matches a closed tool call section, or an unterminated one that
        # runs to the end of the text
        self.tool_call_regex = re.compile(
            r"<tool_calls>(.*?)</tool_calls>|<tool_calls>(.*)", re.DOTALL
        )
        # Kept as a plain string because it is passed to re.sub/re.finditer
        # together with explicit flags
        self.thinking_tag_pattern = r"<think>(.*?)</think>"
        self.tool_name_pattern = re.compile(r'"name":\s*"([^"]+)"')
        self.tool_args_pattern = re.compile(r'"arguments":\s*')

        # Per-stream progress: which tool is active and what was already sent
        self.streaming_state: dict[str, Any] = {
            "current_tool_index": -1,  # -1 means no tool is being streamed
            "tool_ids": [],  # Tool call IDs, indexed by tool position
            "sent_tools": [],  # Per-tool record of what has been emitted
        }

        # Text held back while it might still turn into a tool call tag
        self.pending_buffer = ""
        self.in_thinking_tag = False

        # Resolve the marker tokens to vocabulary IDs (None when absent)
        start_id = self.vocab.get(self.tool_call_start_token)
        end_id = self.vocab.get(self.tool_call_end_token)
        self.tool_call_start_token_id = start_id
        self.tool_call_end_token_id = end_id

        if start_id is None or end_id is None:
            logger.warning(
                "Minimax Tool parser could not locate tool call start/end "
                "tokens in the tokenizer. Falling back to string matching."
            )
72
73
74

    def preprocess_model_output(self, model_output: str) -> str:
        """
75
        Preprocess model output by removing tool calls from thinking tags.
76

77
78
        Args:
            model_output: Raw model output string
79

80
81
        Returns:
            Preprocessed model output with tool calls removed from thinking tags
82
83
84
85
        """

        def remove_tool_calls_from_think(match):
            think_content = match.group(1)
86
87
88
            cleaned_content = re.sub(
                r"<tool_calls>.*?</tool_calls>", "", think_content, flags=re.DOTALL
            )
89
90
            return f"<think>{cleaned_content}</think>"

91
92
93
94
95
96
        return re.sub(
            self.thinking_tag_pattern,
            remove_tool_calls_from_think,
            model_output,
            flags=re.DOTALL,
        )
97

98
99
100
    def _clean_duplicate_braces(self, args_text: str) -> str:
        """
        Clean duplicate closing braces from arguments text.
101

102
103
        Args:
            args_text: Raw arguments text
104

105
106
107
108
109
110
111
112
113
114
115
116
117
        Returns:
            Cleaned arguments text with proper JSON formatting
        """
        args_text = args_text.strip()
        if not args_text:
            return args_text

        try:
            json.loads(args_text)
            return args_text
        except json.JSONDecodeError:
            pass

118
        while args_text.endswith("}}"):
119
120
121
122
123
124
125
126
127
128
129
130
            candidate = args_text[:-1]
            try:
                json.loads(candidate)
                return candidate
            except json.JSONDecodeError:
                args_text = candidate

        return args_text

    def _clean_delta_braces(self, delta_text: str) -> str:
        """
        Clean delta text by removing excessive closing braces.
131

132
133
        Args:
            delta_text: Delta text to clean
134

135
136
137
138
139
140
141
142
        Returns:
            Cleaned delta text
        """
        if not delta_text:
            return delta_text

        delta_stripped = delta_text.strip()

143
144
        if delta_stripped and all(c in "}\n\r\t " for c in delta_stripped):
            brace_count = delta_stripped.count("}")
145
            if brace_count > 1:
146
                return "}\n" if delta_text.endswith("\n") else "}"
147
148

        return delta_text
149
150
151
152
153
154

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """
        Extract tool calls from model output for non-streaming mode.

        Tool call sections inside closed thinking tags are removed before
        parsing. Each remaining section is scanned line by line; every line
        that is a JSON object holding both "name" and "arguments" becomes one
        tool call. The returned content is the part of the original output
        that precedes the first tool call section.

        Args:
            model_output: Complete model output
            request: Chat completion request (not read by this method)

        Returns:
            ExtractedToolCallInformation containing tool calls and content.
            If an unexpected error occurs, it is logged and the raw output
            is returned as plain content with no tool calls.
        """
        processed_output = self.preprocess_model_output(model_output)

        if self.tool_call_start_token not in processed_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        try:
            function_call_tuples = self.tool_call_regex.findall(processed_output)

            raw_function_calls = []
            for match in function_call_tuples:
                # Group 0 is a closed section, group 1 an unterminated one
                tool_call_content = match[0] if match[0] else match[1]
                if not tool_call_content.strip():
                    continue
                for line in tool_call_content.strip().split("\n"):
                    line = line.strip()
                    if line and line.startswith("{") and line.endswith("}"):
                        try:
                            raw_function_calls.append(json.loads(line))
                        except json.JSONDecodeError:
                            # Not a JSON object after all; skip this line
                            continue

            tool_calls = [
                ToolCall(
                    type="function",
                    function=FunctionCall(
                        name=function_call["name"],
                        arguments=json.dumps(
                            function_call["arguments"], ensure_ascii=False
                        ),
                    ),
                )
                for function_call in raw_function_calls
                if "name" in function_call and "arguments" in function_call
            ]

            content = ""
            processed_pos = processed_output.find(self.tool_call_start_token)
            if processed_pos == -1:
                content = model_output
            else:
                # Map the end of the pre-tool text back onto model_output,
                # which may still hold tool calls inside thinking tags.
                prefix = processed_output[:processed_pos]
                search_end = len(prefix)
                for raw_line in reversed(prefix.strip().split("\n")):
                    line = raw_line.strip()
                    if not line:
                        continue

                    # Offset of this very line in the preprocessed text.
                    line_start = prefix.rfind(line, 0, search_end)
                    if line_start == -1:
                        line_start = 0
                    else:
                        search_end = line_start

                    # Preprocessing only deletes text, so the same line sits
                    # at or after that offset in model_output. Searching from
                    # there (not from 0) avoids cutting the content short at
                    # an earlier, identical line.
                    pos = model_output.find(line, line_start)
                    if pos != -1:
                        content = model_output[: pos + len(line)]
                        break

            stripped_content = content.strip()
            return ExtractedToolCallInformation(
                tools_called=len(tool_calls) > 0,
                tool_calls=tool_calls,
                content=stripped_content if stripped_content else None,
            )

        except Exception:
            logger.exception(
                "An unexpected error occurred during tool call extraction."
            )
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )
237

238
239
240
    def _update_thinking_state(self, text: str) -> None:
        """
        Update the thinking tag state based on text content.
241

242
243
244
245
246
247
        Args:
            text: Text to analyze for thinking tags
        """
        open_count = text.count("<think>")
        close_count = text.count("</think>")
        self.in_thinking_tag = open_count > close_count or (
248
249
            open_count == close_count and text.endswith("</think>")
        )
250
251
252
253

    def _is_potential_tag_start(self, text: str) -> bool:
        """
        Check if text might be the start of a tool call tag.
254

255
256
        Args:
            text: Text to check
257

258
259
260
261
262
        Returns:
            True if text could be the start of a tool call tag
        """
        for tag in [self.tool_call_start_token, self.tool_call_end_token]:
            if any(
263
264
265
                tag.startswith(text[-i:])
                for i in range(1, min(len(text) + 1, len(tag)))
            ):
266
267
268
269
270
271
                return True
        return False

    def _should_buffer_content(self, delta_text: str) -> bool:
        """
        Determine if content should be buffered for later processing.
272

273
274
        Args:
            delta_text: Delta text to check
275

276
277
278
279
280
        Returns:
            True if content should be buffered
        """
        if self.in_thinking_tag:
            return False
281
282
283
284
285
286
        return bool(
            self.pending_buffer
            or self.tool_call_start_token in delta_text
            or self.tool_call_end_token in delta_text
            or delta_text.startswith("<")
        )
287
288
289
290

    def _split_content_for_buffering(self, delta_text: str) -> tuple[str, str]:
        """
        Split delta text into safe content and potential tag content.
291

292
293
        Args:
            delta_text: Delta text to split
294

295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
        Returns:
            Tuple of (safe_content, potential_tag_content)
        """
        if self.in_thinking_tag:
            return delta_text, ""

        for tag in [self.tool_call_start_token, self.tool_call_end_token]:
            for i in range(1, len(tag)):
                tag_prefix = tag[:i]
                pos = delta_text.rfind(tag_prefix)
                if pos != -1 and tag.startswith(delta_text[pos:]):
                    return delta_text[:pos], delta_text[pos:]
        return delta_text, ""

    def _process_buffer(self, new_content: str) -> str:
        """
        Process buffered content and return output content.
312

313
314
        Args:
            new_content: New content to add to buffer
315

316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
        Returns:
            Processed output content
        """
        self.pending_buffer += new_content
        output_content = ""

        if self.in_thinking_tag:
            output_content = self.pending_buffer
            self.pending_buffer = ""
            return output_content

        while self.pending_buffer:
            start_pos = self.pending_buffer.find(self.tool_call_start_token)
            end_pos = self.pending_buffer.find(self.tool_call_end_token)

            if start_pos != -1 and (end_pos == -1 or start_pos < end_pos):
                tag_pos, tag_len = start_pos, len(self.tool_call_start_token)
            elif end_pos != -1:
                tag_pos, tag_len = end_pos, len(self.tool_call_end_token)
            else:
                if self._is_potential_tag_start(self.pending_buffer):
                    break
                output_content += self.pending_buffer
                self.pending_buffer = ""
                break

            output_content += self.pending_buffer[:tag_pos]
343
            self.pending_buffer = self.pending_buffer[tag_pos + tag_len :]
344
345
346
347
348
349
350
351
352
353
354
355
356

        return output_content

    def _reset_streaming_state(self) -> None:
        """Reset the streaming state to initial values."""
        self.streaming_state = {
            "current_tool_index": -1,
            "tool_ids": [],
            "sent_tools": [],
        }

    def _advance_to_next_tool(self) -> None:
        """Advance to the next tool in the streaming sequence."""
357
358
359
        self.streaming_state["current_tool_index"] = (
            int(self.streaming_state["current_tool_index"]) + 1
        )
360
361
362
363

    def _set_current_tool_index(self, index: int) -> None:
        """
        Set the current tool index.
364

365
366
367
368
369
370
371
372
        Args:
            index: Tool index to set
        """
        self.streaming_state["current_tool_index"] = index

    def _get_current_tool_index(self) -> int:
        """
        Get the current tool index.
373

374
375
376
377
378
379
380
381
        Returns:
            Current tool index
        """
        return int(self.streaming_state["current_tool_index"])

    def _get_next_unsent_tool_index(self, tool_count: int) -> int:
        """
        Get the index of the next unsent tool.
382

383
384
        Args:
            tool_count: Total number of tools
385

386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
        Returns:
            Index of next unsent tool, or -1 if all tools sent
        """
        sent_tools = list(self.streaming_state["sent_tools"])
        for i in range(tool_count):
            if i < len(sent_tools):
                if not sent_tools[i]["sent_name"]:
                    return i
            else:
                return i
        return -1

    def _ensure_state_arrays(self, tool_count: int) -> None:
        """
        Ensure state arrays have sufficient capacity for tool_count tools.
401

402
403
404
405
406
407
408
        Args:
            tool_count: Number of tools to prepare for
        """
        sent_tools = list(self.streaming_state["sent_tools"])
        tool_ids = list(self.streaming_state["tool_ids"])

        while len(sent_tools) < tool_count:
409
410
411
412
413
414
415
            sent_tools.append(
                {
                    "sent_name": False,
                    "sent_arguments": "",
                    "id": make_tool_call_id(),
                }
            )
416
417
418
419
420
421
422
423
424
425

        while len(tool_ids) < tool_count:
            tool_ids.append(None)

        self.streaming_state["sent_tools"] = sent_tools
        self.streaming_state["tool_ids"] = tool_ids

    def _detect_tools_in_text(self, text: str) -> int:
        """
        Detect the number of tools in text by counting name patterns.
426

427
428
        Args:
            text: Text to analyze
429

430
431
432
433
434
435
436
437
438
        Returns:
            Number of tools detected
        """
        matches = self.tool_name_pattern.findall(text)
        return len(matches)

    def _find_tool_boundaries(self, text: str) -> list[tuple[int, int]]:
        """
        Find the boundaries of tool calls in text.
439

440
441
        Args:
            text: Text to analyze
442

443
444
445
446
447
448
        Returns:
            List of (start, end) positions for tool calls
        """
        boundaries = []
        i = 0
        while i < len(text):
449
            if text[i] == "{":
450
451
452
453
454
455
                start = i
                depth = 0
                has_name = False
                has_arguments = False

                while i < len(text):
456
                    if text[i] == "{":
457
                        depth += 1
458
                    elif text[i] == "}":
459
460
461
462
463
464
465
466
                        depth -= 1
                        if depth == 0:
                            end = i + 1
                            segment = text[start:end]
                            if '"name"' in segment and '"arguments"' in segment:
                                boundaries.append((start, end))
                            break

467
                    if not has_name and '"name"' in text[start : i + 1]:
468
                        has_name = True
469
                    if not has_arguments and '"arguments"' in text[start : i + 1]:
470
471
472
473
474
475
476
477
478
479
                        has_arguments = True

                    i += 1

                if depth > 0 and has_name:
                    boundaries.append((start, i))
            else:
                i += 1
        return boundaries

480
    def _extract_tool_args(self, tool_content: str, args_match: re.Match[str]) -> str:
481
482
        """
        Extract tool arguments from tool content.
483

484
485
486
        Args:
            tool_content: Tool call content
            args_match: Regex match for arguments pattern
487

488
489
490
491
492
493
        Returns:
            Extracted arguments as string
        """
        args_start_pos = args_match.end()
        remaining_content = tool_content[args_start_pos:]

494
        if remaining_content.strip().startswith("{"):
495
496
            depth = 0
            for i, char in enumerate(remaining_content):
497
                if char == "{":
498
                    depth += 1
499
                elif char == "}":
500
501
                    depth -= 1
                    if depth == 0:
502
                        return remaining_content[: i + 1]
503
        else:
504
            args_end = remaining_content.find("}")
505
506
507
            if args_end > 0:
                return remaining_content[:args_end].strip()

508
        return remaining_content.rstrip("}").strip()
509
510

    def _get_current_tool_content(
511
        self, text: str, tool_index: int
512
    ) -> tuple[str | None, str | None]:
513
514
        """
        Get the content of a specific tool by index.
515

516
517
518
        Args:
            text: Text containing tool calls
            tool_index: Index of tool to extract
519

520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
        Returns:
            Tuple of (tool_name, tool_arguments) or (None, None) if not found
        """
        boundaries = self._find_tool_boundaries(text)

        if tool_index >= len(boundaries):
            return None, None

        start, end = boundaries[tool_index]
        tool_content = text[start:end]

        name_match = self.tool_name_pattern.search(tool_content)
        name = name_match.group(1) if name_match else None

        args_match = self.tool_args_pattern.search(tool_content)
        if args_match:
            try:
                args_text = self._extract_tool_args(tool_content, args_match)
                return name, args_text
            except Exception:
540
541
                remaining_content = tool_content[args_match.end() :]
                args_text = remaining_content.rstrip("}").strip()
542
543
544
545
546
                return name, args_text

        return name, None

    def _handle_tool_name_streaming(
        self, tool_content: str, tool_count: int
    ) -> DeltaMessage | None:
        """
        Emit the name of the next tool whose name has not been sent yet.

        On success the tool becomes the current tool, its ID is copied into
        the tool ID list and its record is marked as having its name sent.

        Args:
            tool_content: Content containing tool calls
            tool_count: Total number of tools

        Returns:
            DeltaMessage with tool name or None if no tool to stream
        """
        pending_idx = self._get_next_unsent_tool_index(tool_count)
        if pending_idx == -1:
            return None

        # The tool's object must have started in the text before it is named.
        if pending_idx >= len(self._find_tool_boundaries(tool_content)):
            return None

        name, _ = self._get_current_tool_content(tool_content, pending_idx)
        if not name:
            return None

        self._set_current_tool_index(pending_idx)

        records = list(self.streaming_state["sent_tools"])
        ids = list(self.streaming_state["tool_ids"])

        call_id = records[pending_idx]["id"]
        ids[pending_idx] = call_id
        records[pending_idx]["sent_name"] = True

        self.streaming_state["sent_tools"] = records
        self.streaming_state["tool_ids"] = ids

        name_payload = DeltaFunctionCall(name=name).model_dump(exclude_none=True)
        return DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=pending_idx,
                    type="function",
                    id=call_id,
                    function=name_payload,
                )
            ]
        )
595
596

    def _handle_tool_args_streaming(
        self, tool_content: str, tool_count: int
    ) -> DeltaMessage | None:
        """
        Emit the not-yet-sent part of the current tool's arguments.

        Nothing is emitted until the tool's name has been sent. The first
        emission carries all arguments seen so far; later ones carry only the
        difference, and only when the new arguments extend what was sent.
        Once the cleaned arguments end with a closing brace, the current
        tool index is advanced.

        Args:
            tool_content: Content containing tool calls
            tool_count: Total number of tools

        Returns:
            DeltaMessage with tool arguments or None if no arguments to stream
        """
        active_idx = self._get_current_tool_index()
        if not 0 <= active_idx < tool_count:
            return None

        name, raw_args = self._get_current_tool_content(tool_content, active_idx)
        if not name or raw_args is None:
            return None

        records = list(self.streaming_state["sent_tools"])
        if not records[active_idx]["sent_name"]:
            return None

        clean_args = self._clean_duplicate_braces(raw_args)
        already_sent = records[active_idx]["sent_arguments"]
        if clean_args == already_sent:
            return None

        if already_sent:
            # Only stream when the new text extends what was sent before.
            if not clean_args.startswith(already_sent):
                return None
            pending = extract_intermediate_diff(clean_args, already_sent)
            if not pending:
                return None
        elif clean_args:
            # First emission: everything seen so far is new.
            pending = clean_args
        else:
            return None

        outgoing = self._clean_delta_braces(pending)
        records[active_idx]["sent_arguments"] = clean_args
        self.streaming_state["sent_tools"] = records

        # A trailing closing brace is treated as the end of this tool's arguments.
        if clean_args.endswith("}"):
            self._advance_to_next_tool()

        args_payload = DeltaFunctionCall(arguments=outgoing).model_dump(
            exclude_none=True
        )
        return DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=active_idx,
                    function=args_payload,
                )
            ]
        )

    def _is_end_tool_calls(self, current_text: str) -> bool:
        if self.tool_call_end_token not in current_text:
            return False

        end_token_positions = []
        search_start = 0
        while True:
            pos = current_text.find(self.tool_call_end_token, search_start)
            if pos == -1:
                break
            end_token_positions.append(pos)
            search_start = pos + 1

        think_regions = []
682
683
684
        for match in re.finditer(
            self.thinking_tag_pattern, current_text, flags=re.DOTALL
        ):
685
686
687
            think_regions.append((match.start(), match.end()))

        for pos in end_token_positions:
688
689
690
            in_think = any(
                pos >= t_start and pos < t_end for t_start, t_end in think_regions
            )
691
692
693
694
695
            if not in_think:
                return True

        return False

696
697
698
699
700
701
702
703
704
    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """
        Turn one streamed delta into content and/or tool call deltas.

        Text inside thinking tags is passed through as content. Outside of
        them, text that may belong to a tool call tag is buffered, text in
        front of the tool call section is emitted as content, and the tool
        call section itself is streamed as tool names followed by argument
        deltas.

        Args:
            previous_text: Text generated before this delta (not read here)
            current_text: All text generated so far, including this delta
            delta_text: Newly generated text
            previous_token_ids: Token IDs before this delta (not read here)
            current_token_ids: All token IDs so far (not read here)
            delta_token_ids: Token IDs of this delta
            request: Chat completion request (not read here)

        Returns:
            DeltaMessage with content or tool call data, or None when there
            is nothing to emit for this delta
        """
        self._update_thinking_state(current_text)

        if self.in_thinking_tag:
            return DeltaMessage(content=delta_text)

        if self._should_buffer_content(delta_text):
            buffered_output = self._process_buffer(delta_text)
            return DeltaMessage(content=buffered_output) if buffered_output else None

        if self._is_end_tool_calls(current_text):
            return DeltaMessage(content=delta_text)

        safe_content, potential_tag = self._split_content_for_buffering(delta_text)
        if potential_tag:
            # Hold back what might become a tag; emit only the safe part.
            self.pending_buffer += potential_tag
            return DeltaMessage(content=safe_content) if safe_content else None

        processed_current_text = self.preprocess_model_output(current_text)

        if self.tool_call_start_token not in processed_current_text:
            if (
                self.tool_call_end_token in delta_text
                and self.tool_call_start_token in current_text
            ):
                return None
            if delta_text.strip() == "" and self.tool_call_start_token in current_text:
                return None
            if (
                self._get_current_tool_index() != -1
                and self.tool_call_end_token in current_text
            ):
                self._reset_streaming_state()
            return DeltaMessage(content=delta_text)

        # A delta that is exactly the start token carries nothing to emit.
        if (
            self.tool_call_start_token_id is not None
            and self.tool_call_start_token_id in delta_token_ids
            and len(delta_token_ids) == 1
        ):
            return None

        original_tool_start = self._find_tool_start_outside_thinking(current_text)
        if original_tool_start is None:
            return None

        content_before_tools = self._extract_content_before_tools(
            current_text, delta_text, original_tool_start
        )
        if content_before_tools:
            return DeltaMessage(content=content_before_tools)

        try:
            tool_content = self._extract_tool_content(current_text, original_tool_start)
            current_tools_count = self._detect_tools_in_text(tool_content)

            if current_tools_count == 0:
                return None

            if self._get_current_tool_index() == -1:
                self._reset_streaming_state()

            self._ensure_state_arrays(current_tools_count)

            # Names take priority; arguments are streamed once no name is due.
            return self._handle_tool_name_streaming(
                tool_content, current_tools_count
            ) or self._handle_tool_args_streaming(tool_content, current_tools_count)

        except Exception:
            # One message string: a second positional argument would be
            # treated as a %-format argument and break the log record.
            logger.exception(
                "An unexpected error occurred during streaming tool call handling."
            )
            return None
778

779
    def _find_tool_start_outside_thinking(self, current_text: str) -> int | None:
780
781
        """
        Find the start position of tool calls outside of thinking tags.
782

783
784
        Args:
            current_text: Current text to search
785

786
787
788
789
790
791
792
793
794
        Returns:
            Position of tool call start or None if not found
        """
        search_start = 0
        while True:
            pos = current_text.find(self.tool_call_start_token, search_start)
            if pos == -1:
                return None

795
796
797
798
799
800
801
802
803
            think_regions = [
                (m.start(), m.end())
                for m in re.finditer(
                    r"<think>(.*?)</think>", current_text, flags=re.DOTALL
                )
            ]
            in_think = any(
                pos >= t_start and pos < t_end for t_start, t_end in think_regions
            )
804
805
806
807
808
809

            if not in_think:
                return pos

            search_start = pos + 1

810
811
    def _extract_content_before_tools(
        self, current_text: str, delta_text: str, tool_start: int
812
    ) -> str | None:
813
814
        """
        Extract content that appears before tool calls.
815

816
817
818
819
        Args:
            current_text: Current text
            delta_text: Delta text
            tool_start: Start position of tools
820

821
822
823
824
825
826
827
828
        Returns:
            Content before tools or None
        """
        if tool_start > 0:
            delta_start_pos = len(current_text) - len(delta_text)
            if delta_start_pos < tool_start:
                content_part = delta_text
                if delta_start_pos + len(delta_text) > tool_start:
829
                    content_part = delta_text[: tool_start - delta_start_pos]
830
831
832
833
834
835
                return content_part if content_part else None
        return None

    def _extract_tool_content(self, current_text: str, tool_start: int) -> str:
        """
        Extract tool content from current text starting at tool_start.
836

837
838
839
        Args:
            current_text: Current text
            tool_start: Start position of tool calls
840

841
842
843
844
845
846
847
848
849
850
851
        Returns:
            Extracted tool content
        """
        tool_content_start = tool_start + len(self.tool_call_start_token)
        tool_content = current_text[tool_content_start:]

        end_pos = tool_content.find(self.tool_call_end_token)
        if end_pos != -1:
            tool_content = tool_content[:end_pos]

        return tool_content