# SPDX-License-Identifier: Apache-2.0

import re
from collections.abc import Sequence
from typing import Optional, Union

from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaMessage)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager

logger = init_logger(__name__)


@ReasoningParserManager.register_module("qwen3")
class Qwen3ReasoningParser(ReasoningParser):
    """
    Reasoning parser for the Qwen3 model.

    The Qwen3 model uses <think>...</think> tokens to denote reasoning text
    within its output. The model provides a strict switch to disable
    reasoning output via the 'enable_thinking=False' parameter. This parser
    extracts the reasoning content enclosed by <think> and </think> tokens
    from the model's output.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        """
        Set up the think markers, the extraction regex and the marker ids.

        Raises:
            ValueError: if no model tokenizer is available.
            RuntimeError: if either think marker is missing from the
                tokenizer vocabulary.
        """
        super().__init__(tokenizer)
        # Literal marker text. These must be non-empty: an empty marker
        # "occurs" in every string and cannot be looked up in the vocab.
        self.think_start_token = "<think>"
        self.think_end_token = "</think>"

        # Non-greedy, DOTALL: captures the text between the first <think>
        # and the first </think> after it, across newlines. The markers are
        # escaped so the pattern stays correct for any marker text.
        self.reasoning_regex = re.compile(
            rf"{re.escape(self.think_start_token)}(.*?)"
            rf"{re.escape(self.think_end_token)}", re.DOTALL)

        # self.model_tokenizer / self.vocab are provided by the
        # ReasoningParser base class (not shown in this file).
        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction.")

        self.think_start_token_id = self.vocab.get(self.think_start_token)
        self.think_end_token_id = self.vocab.get(self.think_end_token)
        if (self.think_start_token_id is None
                or self.think_end_token_id is None):
            raise RuntimeError(
                "Qwen3 reasoning parser could not locate think start/end "
                "tokens in the tokenizer!")

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Extract reasoning content from a delta message.

        Handles streaming output where previous + delta = current.
        Uses token IDs for faster processing.
        For text <think>abc</think>xyz:
        - 'abc' goes to reasoning_content
        - 'xyz' goes to content

        Returns:
            None when the delta is exactly one <think> or </think> token
            (the markers themselves are never emitted). Otherwise a
            DeltaMessage carrying reasoning_content and/or content.
        """
        # Skip single special tokens so the markers are not streamed.
        if len(delta_token_ids) == 1 and delta_token_ids[0] in (
                self.think_start_token_id, self.think_end_token_id):
            return None

        if self.think_start_token_id in previous_token_ids:
            if self.think_end_token_id in delta_token_ids:
                # <think> in previous, </think> in delta: the delta holds
                # the tail of the reasoning followed by the start of content.
                # NOTE(review): if delta_text does not contain the literal
                # "</think>" text, find() returns -1 and the split below is
                # wrong — confirm detokenization keeps the marker text.
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[:end_index]
                content = delta_text[end_index + len(self.think_end_token):]
                return DeltaMessage(reasoning_content=reasoning_content,
                                    content=content if content else None)
            if self.think_end_token_id in previous_token_ids:
                # <think> in previous, </think> in previous:
                # reasoning is finished, so this is regular content.
                return DeltaMessage(content=delta_text)
            # <think> in previous, no </think> in previous or delta:
            # reasoning content continues.
            return DeltaMessage(reasoning_content=delta_text)

        if self.think_start_token_id in delta_token_ids:
            if self.think_end_token_id in delta_token_ids:
                # <think> in delta, </think> in delta: the whole reasoning
                # span (and possibly some content) arrived in one chunk.
                start_index = delta_text.find(self.think_start_token)
                end_index = delta_text.find(self.think_end_token)
                reasoning_content = delta_text[
                    start_index + len(self.think_start_token):end_index]
                content = delta_text[end_index + len(self.think_end_token):]
                return DeltaMessage(reasoning_content=reasoning_content,
                                    content=content if content else None)
            # <think> in delta, no </think> in delta:
            # reasoning content has started and continues.
            return DeltaMessage(reasoning_content=delta_text)

        # No <think> seen at all (thinking is disabled): just content.
        return DeltaMessage(content=delta_text)

    def extract_reasoning_content(
            self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:
        """
        Split a complete (non-streaming) model output into reasoning and
        content.

        Args:
            model_output: The full generated text.
            request: The originating request (unused here).

        Returns:
            (reasoning_content, content). If the output has no well-formed
            <think>...</think> span, returns (None, model_output) unchanged.
            content is None when nothing remains after removing the span.
        """
        # Check if the model output contains the <think> and </think> tokens.
        if (self.think_start_token not in model_output
                or self.think_end_token not in model_output):
            return None, model_output

        # Both markers being present does not guarantee a match: in
        # "</think>...<think>" they are out of order, and indexing the
        # result of findall() would raise IndexError.
        match = self.reasoning_regex.search(model_output)
        if match is None:
            return None, model_output

        reasoning_content = match.group(1)
        # Remove the whole <think>...</think> span while keeping any text
        # before it. Although <think> is normally at the beginning of the
        # output, other models may not follow this convention.
        content = model_output[:match.start()] + model_output[match.end():]
        return reasoning_content, content if content else None