abs_reasoning_parsers.py 11.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import importlib
5
import os
6
from abc import abstractmethod
7
from collections.abc import Callable, Iterable, Sequence
8
from functools import cached_property
9
from typing import TYPE_CHECKING, cast
10

11
from vllm.entrypoints.mcp.tool_server import ToolServer
12
from vllm.logger import init_logger
13
from vllm.utils.collection_utils import is_list_of
14
from vllm.utils.import_utils import import_from_path
15

16
if TYPE_CHECKING:
    # Imported only for static type checking. The names below are used in
    # quoted (string) annotations, so they are not required at runtime.
    from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
    from vllm.entrypoints.openai.engine.protocol import DeltaMessage
    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
    from vllm.tokenizers import TokenizerLike


# Module-level logger for registry/import diagnostics.
logger = init_logger(__name__)


class ReasoningParser:
    """
27
    Abstract reasoning parser class that should not be used directly.
28
29
30
31
32
    Provided and methods should be used in derived classes.

    It is used to extract reasoning content from the model output.
    """

33
    def __init__(self, tokenizer: "TokenizerLike", *args, **kwargs):
34
35
36
        self.model_tokenizer = tokenizer

    @cached_property
37
    def vocab(self) -> dict[str, int]:
38
39
40
41
        # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab
        # whereas all tokenizers have .get_vocab()
        return self.model_tokenizer.get_vocab()

42
43
44
45
46
47
48
49
50
51
52
53
54
55
    @property
    def reasoning_start_str(self) -> str | None:
        """Set `reasoning_start_str` to the strings that delimit
        the reasoning block (e.g. `""<seed:think>""` and `"<think>"`).
        """
        return None

    @property
    def reasoning_end_str(self) -> str | None:
        """Set `reasoning_end_str` to the strings that delimit
        the reasoning block (e.g. `""</seed:think>""` and `"</think>"`).
        """
        return None

56
    @abstractmethod
57
    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
        """
        Check if the reasoning content ends in the input_ids.

        It is used in structured engines like `xgrammar` to check if the
        reasoning content ends in the model output.

        Parameters:
        input_ids: list[int]
            The input_ids of the model output.

        Returns:
        bool
            True if the reasoning content ends in the input_ids.
        """

73
    def is_reasoning_end_streaming(
74
        self, input_ids: Sequence[int], delta_ids: Iterable[int]
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
    ) -> bool:
        """
        Check if the reasoning content ends in the input_ids on a
        decode step.

        It is used in structured engines like `xgrammar` to check if the
        reasoning content ends in the model output during a decode step.
        `input_ids` the entire model output and `delta_ids` are the last few
        computed tokens of the model output (like during a decode step).

        Parameters:
        input_ids: list[int]
            The entire model output.
        delta_ids: list[int]
            The last few computed tokens of the model output at the current decode step.

        Returns:
        bool
            True if the reasoning content ends in the `delta_ids` on a
            decode step.
        """
        return self.is_reasoning_end(input_ids)

98
99
100
101
102
103
104
105
106
107
108
109
    @abstractmethod
    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """
        Extract content token ids from the input_ids.
        Parameters:
        input_ids: list[int]
            The input_ids of the model output.
        Returns:
        list[int]
            The extracted content from the input_ids.
        """

110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
    def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
        """Count the number of reasoning tokens in a sequence.

        Text-based reasoning models typically wrap their chain-of-thought
        between special start/end tokens (e.g., ``<think> ... </think>``).
        Implementations that support reasoning token counting should override
        this method. The default implementation returns ``0`` so existing
        parsers remain unchanged unless they explicitly opt in.

        Args:
            token_ids: Sequence of generated token ids (excluding prompt).

        Returns:
            int: Number of tokens that belong to reasoning content.
        """

        # By default, assume the parser cannot detect reasoning spans.
        return 0

129
    @abstractmethod
130
    def extract_reasoning(
131
132
        self,
        model_output: str,
133
        request: "ChatCompletionRequest | ResponsesRequest",
134
    ) -> tuple[str | None, str | None]:
135
136
137
138
139
140
141
        """
        Extract reasoning content from a complete model-generated string.

        Used for non-streaming responses where we have the entire model response
        available before sending to the client.

        Parameters:
142
143
            model_output: The model-generated string to extract reasoning content from.
            request: The request object that was used to generate the model_output.
144
145
146
147
148

        Returns:
            A tuple containing the reasoning content and the content.
        """

149
    @abstractmethod
150
    def extract_reasoning_streaming(
151
152
153
154
155
156
157
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
158
    ) -> "DeltaMessage | None":
159
160
161
162
163
164
165
        """
        Instance method that should be implemented for extracting reasoning
        from an incomplete response; for use when handling reasoning calls and
        streaming. Has to be an instance method because  it requires state -
        the current tokens/diffs, but also the information about what has
        previously been parsed and extracted (see constructor)
        """
166

167
168
169
170
171
172
    def adjust_request(
        self, request: "ChatCompletionRequest | ResponsesRequest"
    ) -> "ChatCompletionRequest | ResponsesRequest":
        """Adjust request parameters; override in subclasses as needed."""
        return request

173
174
175
176
    def prepare_structured_tag(
        self,
        original_tag: str | None,
        tool_server: ToolServer | None,
Ning Xie's avatar
Ning Xie committed
177
    ) -> str | None:
178
179
180
181
182
183
        """
        Instance method that is implemented for preparing the structured tag
        Otherwise, None is returned
        """
        return None

184
185

class ReasoningParserManager:
    """
    Central registry for ReasoningParser implementations.

    Supports two registration modes:
      - Eager registration via `register_module(..., module=cls)`
      - Lazy registration via `register_lazy_module` or the decorator form
        of `register_module`; both record `(module_path, class_name)` and
        the class is imported on first lookup

    Each reasoning parser must inherit from `ReasoningParser`.
    """

    # name -> parser class. Holds eager registrations and also caches lazily
    # registered parsers once they have been imported.
    reasoning_parsers: dict[str, type[ReasoningParser]] = {}
    lazy_parsers: dict[str, tuple[str, str]] = {}  # name -> (module_path, class_name)

    @classmethod
    def get_reasoning_parser(cls, name: str) -> type[ReasoningParser]:
        """
        Retrieve a registered or lazily registered ReasoningParser class.

        Entries in `reasoning_parsers` take precedence. Otherwise, if the
        parser is lazily registered, it is imported and cached on first
        access.

        Raises:
            KeyError: if no parser is found under the given name.
        """
        if name in cls.reasoning_parsers:
            return cls.reasoning_parsers[name]

        if name in cls.lazy_parsers:
            return cls._load_lazy_parser(name)

        registered = ", ".join(cls.list_registered())
        raise KeyError(
            f"Reasoning parser '{name}' not found. Available parsers: {registered}"
        )

    @classmethod
    def list_registered(cls) -> list[str]:
        """Return names of all eagerly and lazily registered reasoning parsers."""
        return sorted(set(cls.reasoning_parsers.keys()) | set(cls.lazy_parsers.keys()))

    @classmethod
    def _load_lazy_parser(cls, name: str) -> type[ReasoningParser]:
        """Import and register a lazily loaded reasoning parser.

        Raises:
            ImportError / AttributeError: if the module or class is missing.
            TypeError: if the resolved object is not a ReasoningParser subclass.
        Any failure is logged with its traceback and then re-raised.
        """
        module_path, class_name = cls.lazy_parsers[name]
        try:
            mod = importlib.import_module(module_path)
            parser_cls = getattr(mod, class_name)
            # Check for a class first: issubclass() raises its own, less
            # informative TypeError when handed a non-class object.
            if not (
                isinstance(parser_cls, type) and issubclass(parser_cls, ReasoningParser)
            ):
                raise TypeError(
                    f"{class_name} in {module_path} is not a ReasoningParser subclass."
                )

            cls.reasoning_parsers[name] = parser_cls  # cache
            return parser_cls
        except Exception as e:
            logger.exception(
                "Failed to import lazy reasoning parser '%s' from %s: %s",
                name,
                module_path,
                e,
            )
            raise

    @staticmethod
    def _resolve_names(name: str | list[str] | None, default: str) -> list[str]:
        """Normalize a registration name argument into a list of names.

        `None` maps to `[default]`; a string maps to a one-element list.

        Raises:
            TypeError: if `name` is not a str, a list of str, or None.
        """
        if name is None:
            return [default]
        if isinstance(name, str):
            return [name]
        if is_list_of(name, str):
            return cast(list[str], name)
        raise TypeError("module_name must be str, list[str], or None.")

    @classmethod
    def _check_not_registered(cls, name: str, target: tuple[str, str]) -> None:
        """Raise KeyError if `name` is already taken (used when force=False).

        An entry in `reasoning_parsers` always conflicts. A lazy entry
        conflicts only when it points at a different `(module_path,
        class_name)` than `target`. Re-declaring the identical lazy mapping
        is allowed, as happens when a decorated module is imported by
        `_load_lazy_parser`.
        """
        if name in cls.reasoning_parsers:
            existed = cls.reasoning_parsers[name]
            raise KeyError(f"{name} is already registered at {existed.__module__}")
        lazy_target = cls.lazy_parsers.get(name)
        if lazy_target is not None and lazy_target != target:
            raise KeyError(f"{name} is already registered at {lazy_target[0]}")

    @classmethod
    def _register_module(
        cls,
        module: type[ReasoningParser],
        module_name: str | list[str] | None = None,
        force: bool = True,
    ) -> None:
        """Register a ReasoningParser class immediately.

        Args:
            module: The parser class to register.
            module_name: Name(s) to register under; defaults to the class name.
            force: If False, refuse to replace an existing registration.

        Raises:
            TypeError: if `module` is not a ReasoningParser subclass or
                `module_name` has an unsupported type.
            KeyError: if `force` is False and a name is already taken.
        """
        # Check for a class first so non-class inputs get this message rather
        # than issubclass()'s generic one.
        if not (isinstance(module, type) and issubclass(module, ReasoningParser)):
            raise TypeError(
                f"module must be subclass of ReasoningParser, but got {module!r}"
            )

        module_names = cls._resolve_names(module_name, module.__name__)
        target = (module.__module__, module.__name__)

        for name in module_names:
            if not force:
                cls._check_not_registered(name, target)
            cls.reasoning_parsers[name] = module

    @classmethod
    def register_lazy_module(cls, name: str, module_path: str, class_name: str) -> None:
        """
        Register a lazy module mapping for delayed import.

        Overwrites any previous lazy mapping under `name`.

        Example:
            ReasoningParserManager.register_lazy_module(
                name="qwen3",
                module_path="vllm.reasoning.parsers.qwen3_reasoning_parser",
                class_name="Qwen3ReasoningParser",
            )
        """
        cls.lazy_parsers[name] = (module_path, class_name)

    @classmethod
    def register_module(
        cls,
        name: str | list[str] | None = None,
        force: bool = True,
        module: type[ReasoningParser] | None = None,
    ) -> (
        type[ReasoningParser] | Callable[[type[ReasoningParser]], type[ReasoningParser]]
    ):
        """
        Register module with the given name or name list. It can be used as a
        decorator (with module as None) or a normal function (with module as
        not None).

        With `module` given, the class is registered eagerly. As a decorator,
        the class is recorded lazily by `(obj.__module__, obj.__name__)`.
        In both modes `force=False` raises KeyError instead of replacing a
        conflicting registration.

        Raises:
            TypeError: if `force` is not a bool or `name` has an unsupported type.
            KeyError: if `force` is False and a name is already taken.
        """
        if not isinstance(force, bool):
            raise TypeError(f"force must be a boolean, but got {type(force)}")

        # Immediate registration (explicit call)
        if module is not None:
            cls._register_module(module=module, module_name=name, force=force)
            return module

        # Decorator usage
        def _decorator(obj: type[ReasoningParser]) -> type[ReasoningParser]:
            class_name = obj.__name__
            target = (obj.__module__, class_name)

            for n in cls._resolve_names(name, class_name):
                if not force:
                    cls._check_not_registered(n, target)
                cls.lazy_parsers[n] = target

            return obj

        return _decorator

    @classmethod
    def import_reasoning_parser(cls, plugin_path: str) -> None:
        """
        Import a user-defined reasoning parser by the path
        of the file that defines the reasoning parser.

        The module name is derived from the file name (without extension).
        Import failures are logged and swallowed (best effort).
        """
        module_name = os.path.splitext(os.path.basename(plugin_path))[0]

        try:
            import_from_path(module_name, plugin_path)
        except Exception:
            logger.exception(
                "Failed to load module '%s' from %s.", module_name, plugin_path
            )
            return