Unverified commit 32b14baf, authored by Ce Gao, committed by GitHub
Browse files

[Refactor][Frontend] Keep all logic about reasoning into one class (#14428)


Signed-off-by: Ce Gao <cegao@tensorchord.ai>
parent 2d9045fc
......@@ -3,74 +3,92 @@
import pytest
from transformers import AutoTokenizer
from tests.entrypoints.openai.reasoning_parsers.utils import (
run_reasoning_extraction)
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
ReasoningParserManager)
from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager
# Name under which the parser class is looked up through
# ReasoningParserManager.get_reasoning_parser() in test_reasoning.
parser_name = "deepseek_r1"
# Textual markers that delimit the reasoning section in the sample outputs.
start_token = "<think>"
end_token = "</think>"
# Model identifier handed to AutoTokenizer.from_pretrained() by the
# module-scoped tokenizer fixture.
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
@pytest.fixture(scope="module")
def deepseek_r1_qwen_tokenizer():
    """Provide the tokenizer for REASONING_MODEL_NAME.

    The fixture is module-scoped, so the tokenizer is loaded once and shared
    by every test in this module.
    """
    tokenizer = AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
    return tokenizer
# Expected results for the DeepSeek R1 reasoning parser, one dict per scenario.
#   "output"            raw model text fed to the parser
#   "reasoning_content" expected extracted reasoning text
#   "content"           expected remaining content (None when there is none)
#   "is_reasoning_end"  expected parser.is_reasoning_end() result for the
#                       token ids of "output"

# Reasoning closed by "</think>" and followed by regular content.
SIMPLE_REASONING = {
"output": "This is a reasoning section</think>This is the rest",
"reasoning_content": "This is a reasoning section",
"content": "This is the rest",
"is_reasoning_end": True,
}
# Output stops right after "</think>": reasoning only, no content.
COMPLETE_REASONING = {
"output": "This is a reasoning section</think>",
"reasoning_content": "This is a reasoning section",
"content": None,
"is_reasoning_end": True,
}
# No "</think>" at all: the whole output is expected back as reasoning.
NO_CONTENT = {
"output": "This is content",
"reasoning_content": "This is content",
"content": None,
"is_reasoning_end": False,
}
# Same shape as NO_CONTENT (no end marker); presumably paired with the
# streaming mode in TEST_CASES -- confirm against the full case list.
NO_REASONING_STREAMING = {
"output": "This is a reasoning section",
"reasoning_content": "This is a reasoning section",
"content": None,
"is_reasoning_end": False,
}
# Newlines inside both the reasoning section and the content.
MULTIPLE_LINES = {
"output": "This\nThat</think>This is the rest\nThat",
"reasoning_content": "This\nThat",
"content": "This is the rest\nThat",
"is_reasoning_end": True,
}
# "</think>" is the very first thing in the output; here the expected
# reasoning is the empty string (presumably the non-streaming path).
SHORTEST_REASONING_NO_STREAMING = {
"output": "</think>This is the rest",
"reasoning_content": "",
"content": "This is the rest",
"is_reasoning_end": True,
}
# Same output as above, but the expected reasoning is None
# (presumably the streaming path -- confirm against TEST_CASES).
SHORTEST_REASONING = {
"output": "</think>This is the rest",
"reasoning_content": None,
"content": "This is the rest",
"is_reasoning_end": True,
}
# The *_WITH_THINK cases below mirror the ones above with an explicit leading
# "<think>" in the output, except where noted.
REASONING_WITH_THINK = {
"output": "<think>This is a reasoning section</think>This is the rest",
"reasoning_content": "This is a reasoning section",
"content": "This is the rest",
"is_reasoning_end": True,
}
COMPLETE_REASONING_WITH_THINK = {
"output": "<think>This is a reasoning section</think>",
"reasoning_content": "This is a reasoning section",
"content": None,
"is_reasoning_end": True,
}
MULTIPLE_LINES_WITH_THINK = {
"output": "<think>This\nThat</think>This is the rest\nThat",
"reasoning_content": "This\nThat",
"content": "This is the rest\nThat",
"is_reasoning_end": True,
}
# NOTE(review): despite the name, this output has no leading "<think>" and is
# identical to SHORTEST_REASONING_NO_STREAMING -- confirm this is intended.
SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
"output": "</think>This is the rest",
"reasoning_content": "",
"content": "This is the rest",
"is_reasoning_end": True,
}
# NOTE(review): likewise identical to SHORTEST_REASONING (no "<think>" in the
# output) -- confirm this is intended.
SHORTEST_REASONING_WITH_THINK = {
"output": "</think>This is the rest",
"reasoning_content": None,
"content": "This is the rest",
"is_reasoning_end": True,
}
TEST_CASES = [
......@@ -166,23 +184,21 @@ TEST_CASES = [
),
]
# Global tokenizer initialization to avoid repeated loading
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
tokenizer.add_tokens([start_token, end_token])
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
streaming: bool,
param_dict: dict,
deepseek_r1_qwen_tokenizer,
):
output = tokenizer.tokenize(param_dict["output"])
output = deepseek_r1_qwen_tokenizer.tokenize(param_dict["output"])
# decode everything to tokens
output_tokens: list[str] = [
tokenizer.convert_tokens_to_string([token]) for token in output
deepseek_r1_qwen_tokenizer.convert_tokens_to_string([token])
for token in output
]
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
parser_name)(tokenizer)
parser_name)(deepseek_r1_qwen_tokenizer)
reasoning, content = run_reasoning_extraction(parser,
output_tokens,
......@@ -190,3 +206,17 @@ def test_reasoning(
assert reasoning == param_dict["reasoning_content"]
assert content == param_dict["content"]
# Test is_reasoning_end
output_ids = deepseek_r1_qwen_tokenizer.convert_tokens_to_ids(output)
is_reasoning_end = parser.is_reasoning_end(output_ids)
assert is_reasoning_end == param_dict["is_reasoning_end"]
# Test extract_content
if param_dict["content"] is not None:
content = parser.extract_content_ids(output_ids)
assert content == deepseek_r1_qwen_tokenizer.convert_tokens_to_ids(
deepseek_r1_qwen_tokenizer.tokenize(param_dict["content"]))
else:
content = parser.extract_content_ids(output)
assert content == []
......@@ -2,10 +2,8 @@
import pytest
from transformers import AutoTokenizer
from tests.entrypoints.openai.reasoning_parsers.utils import (
DeltaMessage, run_reasoning_extraction)
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
ReasoningParserManager)
from tests.reasoning.utils import DeltaMessage, run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager
parser_name = "granite"
START_REASONING = "Here is my thought process:"
......
......@@ -4,7 +4,7 @@ from typing import Optional, Union
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParser
from vllm.reasoning import ReasoningParser
class StreamingReasoningReconstructor:
......
......@@ -23,6 +23,7 @@ from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.plugins import load_general_plugins
from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
......@@ -1119,7 +1120,7 @@ class EngineArgs:
parser.add_argument(
"--reasoning-parser",
type=str,
choices=["deepseek_r1", "granite"],
choices=list(ReasoningParserManager.reasoning_parsers),
default=None,
help=
"Select the reasoning parser depending on the model that you're "
......
......@@ -2080,7 +2080,8 @@ class LLMEngine:
guided_decoding.backend = guided_decoding.backend or \
self.decoding_config.guided_decoding_backend
logger.debug("Reasoning backend: %s",
if self.decoding_config.reasoning_backend is not None:
logger.debug("Building with reasoning backend %s",
self.decoding_config.reasoning_backend)
processor = get_local_guided_decoding_logits_processor(
......
......@@ -68,7 +68,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TranscriptionRequest,
TranscriptionResponse,
UnloadLoRAAdapterRequest)
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
......@@ -85,6 +84,7 @@ from vllm.entrypoints.openai.serving_transcription import (
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import load_aware_call, with_cancellation
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import MistralTokenizer
......
......@@ -23,8 +23,6 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
RequestResponseMetadata, ToolCall, UsageInfo)
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
ReasoningParserManager)
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
clamp_prompt_logprobs)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
......@@ -33,6 +31,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
MistralToolCall)
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
......
......@@ -5,10 +5,10 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.reasoner import get_reasoner
from vllm.model_executor.guided_decoding.utils import (
convert_lark_to_gbnf, grammar_is_likely_lark,
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
from vllm.reasoning import ReasoningParserManager
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
......@@ -107,7 +107,11 @@ async def get_guided_decoding_logits_processor(
model_config: ModelConfig,
reasoning_backend: str | None = None) -> LogitsProcessor | None:
reasoner = get_reasoner(tokenizer, reasoning_backend)
reasoner = None
if reasoning_backend is not None:
reasoner_class = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
reasoner = reasoner_class(tokenizer)
guided_params = maybe_backend_fallback(guided_params)
......@@ -146,8 +150,11 @@ def get_local_guided_decoding_logits_processor(
reasoning_backend: str | None = None) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params)
# Get the reasoner if needed, it will be None if reasoning_
reasoner = get_reasoner(tokenizer, reasoning_backend)
reasoner = None
if reasoning_backend is not None:
reasoner_class = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
reasoner = reasoner_class(tokenizer)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend_name == 'outlines':
......
......@@ -12,7 +12,7 @@ from transformers import PreTrainedTokenizerBase
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams
......@@ -61,7 +61,7 @@ _MAX_THREADPOOL_WORKERS = 16
async def get_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner],
reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""
......@@ -92,7 +92,7 @@ async def get_outlines_guided_decoding_logits_processor(
def get_local_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner],
reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""
......@@ -141,7 +141,7 @@ def _get_logits_processor(
tokenizer: PreTrainedTokenizerBase,
mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None],
reasoner: Optional[Reasoner],
reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,
......
......@@ -34,8 +34,8 @@ from transformers import PreTrainedTokenizerBase
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.platforms import current_platform
from vllm.reasoning import ReasoningParser
logger = init_logger(__name__)
......@@ -49,9 +49,9 @@ else:
class BaseLogitsProcessor:
def __init__(self, guide: Guide, reasoner: Optional[Reasoner]):
def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]):
self._guide: Guide = guide
self._reasoner: Optional[Reasoner] = reasoner
self._reasoner: Optional[ReasoningParser] = reasoner
# CFGState is used for the FSM state for CFGGuide
self._fsm_state: DefaultDict[int, Union[int,
CFGState]] = defaultdict(int)
......@@ -69,7 +69,7 @@ class BaseLogitsProcessor:
# Remove the reasoning tokens from the input_ids
# We need this because our implementation relies on the
# hash of the input_ids to store the FSM state.
input_ids = self._reasoner.extract_content(input_ids)
input_ids = self._reasoner.extract_content_ids(input_ids)
seq_id = hash(tuple(input_ids))
......@@ -142,7 +142,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
self,
regex_string: str,
tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner],
reasoner: Optional[ReasoningParser],
):
"""Compile the FSM that drives the regex-structured generation.
......@@ -163,7 +163,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self, schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Union[str, None],
reasoner: Optional[Reasoner]):
reasoner: Optional[ReasoningParser]):
"""Compile the FSM that drives the JSON-guided generation.
Parameters
......@@ -203,7 +203,7 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
return CFGGuide(cfg, tokenizer)
def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner]):
reasoner: Optional[ReasoningParser]):
"""Compile the FSM that drives the context free grammar generation.
Parameters
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner
@dataclass
class DeepSeekReasoner(Reasoner):
    """
    Reasoner for DeepSeek R series models.

    Works on generated token ids and uses the id of the end marker to decide
    where the reasoning section stops.
    """
    # Token ids of the start/end markers in the tokenizer's vocabulary.
    start_token_id: int
    end_token_id: int

    # Textual form of the markers that delimit the reasoning section.
    start_token: str = "<think>"
    end_token: str = "</think>"

    @classmethod
    def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
        """
        Build a reasoner by resolving the marker token ids with `tokenizer`.

        The markers are taken from the class-level `start_token` /
        `end_token` defaults rather than repeated literals, so the ids always
        match the markers declared on the class (or on a subclass that
        overrides them).
        """
        # Only the first id of each encoding is kept.
        # NOTE(review): this assumes each marker encodes to a single token --
        # confirm for the tokenizers in use.
        start_token_id = tokenizer.encode(cls.start_token,
                                          add_special_tokens=False)[0]
        end_token_id = tokenizer.encode(cls.end_token,
                                        add_special_tokens=False)[0]
        return cls(start_token_id=start_token_id, end_token_id=end_token_id)

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return True if the end-marker token id occurs in `input_ids`."""
        return self.end_token_id in input_ids

    def extract_content(self, input_ids: list[int]) -> list[int]:
        """
        Extract the content after the end tokens

        Returns the ids following the first end-marker id, or an empty list
        when the marker is absent or is the last element.
        """
        try:
            # One scan instead of a membership test plus repeated index().
            end_index = input_ids.index(self.end_token_id)
        except ValueError:
            return []
        # Slicing past the end yields [] when the marker is the last id.
        return input_ids[end_index + 1:]
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
@dataclass
class Reasoner(ABC):
    """
    Abstract interface for locating the end of a model's reasoning section
    in a list of generated token ids.
    """

    # `from_tokenizer` receives `cls`, so it is declared as a classmethod.
    # `@classmethod` must be the outermost decorator when combined with
    # `@abstractmethod`.
    @classmethod
    @abstractmethod
    def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
        """Alternate constructor: build a reasoner from `tokenizer`."""

    @abstractmethod
    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return whether the reasoning section has ended in `input_ids`."""

    @abstractmethod
    def extract_content(self, input_ids: list[int]) -> list[int]:
        """Return the content token ids extracted from `input_ids`."""
......@@ -27,7 +27,7 @@ if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams
logger = init_logger(__name__)
......@@ -37,7 +37,7 @@ def get_local_xgrammar_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer,
model_config: ModelConfig,
reasoner: Reasoner | None,
reasoner: ReasoningParser | None,
max_threads: int = 8):
config = GrammarConfig.from_guided_params(guided_params=guided_params,
model_config=model_config,
......@@ -280,7 +280,7 @@ class GrammarConfig:
class XGrammarLogitsProcessor:
"""Wrapper class to support pickle protocol"""
config: GrammarConfig
reasoner: Reasoner | None = None
reasoner: ReasoningParser | None = None
ctx: xgr.CompiledGrammar | None = None
tokenizer_info: xgr.TokenizerInfo = None # type: ignore[assignment]
......
......@@ -32,6 +32,36 @@ class ReasoningParser:
# whereas all tokenizers have .get_vocab()
return self.model_tokenizer.get_vocab()
@abstractmethod
def is_reasoning_end(self, input_ids: list[int]) -> bool:
"""
Check if the reasoning content ends in the input_ids.
It is used in structured engines like `xgrammar` to check if the
reasoning content ends in the model output.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
bool
True if the reasoning content ends in the input_ids.
"""
@abstractmethod
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract content token ids from the input_ids.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
list[int]
The extracted content from the input_ids.
"""
@abstractmethod
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]:
......@@ -53,10 +83,7 @@ class ReasoningParser:
A tuple containing the reasoning content and the content.
"""
raise NotImplementedError(
"AbstractReasoningParser.extract_reasoning_calls "
"has not been implemented!")
@abstractmethod
def extract_reasoning_content_streaming(
self,
previous_text: str,
......@@ -73,43 +100,6 @@ class ReasoningParser:
the current tokens/diffs, but also the information about what has
previously been parsed and extracted (see constructor)
"""
raise NotImplementedError(
"AbstractReasoningParser.extract_reasoning_content_streaming "
"has not been implemented!")
# TODO: need to rebase by PR #14428
@abstractmethod
def is_reasoning_end(self, input_ids: list[int]) -> bool:
"""
Check if the reasoning content ends in the input_ids.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
bool
True if the reasoning content ends in the input_ids.
"""
raise NotImplementedError(
"AbstractReasoningParser.is_reasoning_end has"
"not been implemented!")
# TODO: need to rebase by PR #14428
@abstractmethod
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract content token ids from the input_ids.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
list[int]
The extracted content from the input_ids.
"""
raise NotImplementedError(
"AbstractReasoningParser.extract_content_ids has"
" not been implemented!")
class ReasoningParserManager:
......@@ -125,14 +115,16 @@ class ReasoningParserManager:
if name in cls.reasoning_parsers:
return cls.reasoning_parsers[name]
raise KeyError(f"reasoning helper: '{name}' not found in "
"reasoning_parsers")
raise KeyError(
f"reasoning helper: '{name}' not found in reasoning_parsers")
@classmethod
def _register_module(cls,
def _register_module(
cls,
module: type,
module_name: Optional[Union[str, list[str]]] = None,
force: bool = True) -> None:
force: bool = True,
) -> None:
if not issubclass(module, ReasoningParser):
raise TypeError("module must be subclass of ReasoningParser, "
f"but got {type(module)}")
......@@ -152,7 +144,8 @@ class ReasoningParserManager:
cls,
name: Optional[Union[str, list[str]]] = None,
force: bool = True,
module: Union[type, None] = None) -> Union[type, Callable]:
module: Union[type, None] = None,
) -> Union[type, Callable]:
"""
Register module with the given name or name list. it can be used as a
decoder(with module as None) or normal function(with module as not
......
......@@ -8,9 +8,8 @@ from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import (
ReasoningParser, ReasoningParserManager)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager
logger = init_logger(__name__)
......@@ -24,39 +23,41 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
text. This parser extracts the reasoning content from the model output.
"""
start_token_id: int
end_token_id: int
start_token: str = "<think>"
end_token: str = "</think>"
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
self.think_start_token = "<think>"
self.think_end_token = "</think>"
self.reasoning_regex = re.compile(
rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL)
rf"{self.start_token}(.*?){self.end_token}", re.DOTALL)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ReasoningParser "
"constructor during construction.")
self.think_start_token_id = self.vocab.get(self.think_start_token)
self.think_end_token_id = self.vocab.get(self.think_end_token)
if (self.think_start_token_id is None
or self.think_end_token_id is None):
self.start_token_id = self.vocab.get(self.start_token)
self.end_token_id = self.vocab.get(self.end_token)
if self.start_token_id is None or self.end_token_id is None:
raise RuntimeError(
"DeepSeek R1 reasoning parser could not locate think start/end "
"tokens in the tokenizer!")
# TODO: need to rebase by PR #14428
def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.think_end_token_id in input_ids
return self.end_token_id in input_ids
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract the content after the end tokens
"""
if self.think_end_token_id not in input_ids[:-1]:
if self.end_token_id not in input_ids[:-1]:
return []
else:
return input_ids[input_ids.index(self.think_end_token_id) + 1:]
return input_ids[input_ids.index(self.end_token_id) + 1:]
def extract_reasoning_content_streaming(
self,
......@@ -77,22 +78,24 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
"""
# Skip single special tokens
if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
self.think_start_token_id, self.think_end_token_id
self.start_token_id, self.end_token_id
]):
return None
# Check if <think> is present in previous or delta.
# Keep compatibility with models that don't generate <think> tokens.
if self.think_start_token_id in previous_token_ids:
if self.think_end_token_id in delta_token_ids:
if self.start_token_id in previous_token_ids:
if self.end_token_id in delta_token_ids:
# <think> in previous, </think> in delta,
# extract reasoning content
end_index = delta_text.find(self.think_end_token)
end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[:end_index]
content = delta_text[end_index + len(self.think_end_token):]
return DeltaMessage(reasoning_content=reasoning_content,
content=content if content else None)
elif self.think_end_token_id in previous_token_ids:
content = delta_text[end_index + len(self.end_token):]
return DeltaMessage(
reasoning_content=reasoning_content,
content=content if content else None,
)
elif self.end_token_id in previous_token_ids:
# <think> in previous, </think> in previous,
# reasoning content continues
return DeltaMessage(content=delta_text)
......@@ -100,17 +103,18 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# <think> in previous, no </think> in previous or delta,
# reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
elif self.think_start_token_id in delta_token_ids:
if self.think_end_token_id in delta_token_ids:
elif self.start_token_id in delta_token_ids:
if self.end_token_id in delta_token_ids:
# <think> in delta, </think> in delta, extract reasoning content
start_index = delta_text.find(self.think_start_token)
end_index = delta_text.find(self.think_end_token)
start_index = delta_text.find(self.start_token)
end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[start_index +
len(self.think_start_token
):end_index]
content = delta_text[end_index + len(self.think_end_token):]
return DeltaMessage(reasoning_content=reasoning_content,
content=content if content else None)
len(self.start_token):end_index]
content = delta_text[end_index + len(self.end_token):]
return DeltaMessage(
reasoning_content=reasoning_content,
content=content if content else None,
)
else:
# <think> in delta, no </think> in delta,
# reasoning content continues
......@@ -119,15 +123,17 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# No <think> in previous or delta, also need to check for </think>.
# Because the model may have generated </think> without <think>
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if self.think_end_token_id in delta_token_ids:
if self.end_token_id in delta_token_ids:
# </think> in delta with more tokens,
# extract reasoning content and content
end_index = delta_text.find(self.think_end_token)
end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[:end_index]
content = delta_text[end_index + len(self.think_end_token):]
return DeltaMessage(reasoning_content=reasoning_content,
content=content if content else None)
elif self.think_end_token_id in previous_token_ids:
content = delta_text[end_index + len(self.end_token):]
return DeltaMessage(
reasoning_content=reasoning_content,
content=content if content else None,
)
elif self.end_token_id in previous_token_ids:
# </think> in previous, thinking content ends
return DeltaMessage(content=delta_text)
else:
......@@ -137,22 +143,20 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]:
# DeepSeek R1 doesn't generate <think> now.
# Thus we assume the reasoning content is always at the start.
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if self.think_end_token not in model_output:
if self.end_token not in model_output:
return model_output, None
else:
# Add a start token if it's missing to keep compatibility.
if self.think_start_token not in model_output:
model_output = f"{self.think_start_token}{model_output}"
if self.start_token not in model_output:
model_output = f"{self.start_token}{model_output}"
# Use a regex to find the reasoning content
reasoning_content = self.reasoning_regex.findall(model_output)[0]
end_index = len(
f"{self.think_start_token}{reasoning_content}{self.think_end_token}"
)
f"{self.start_token}{reasoning_content}{self.end_token}")
final_output = model_output[end_index:]
if len(final_output) == 0:
......
......@@ -8,9 +8,8 @@ from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import (
ReasoningParser, ReasoningParserManager)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager
logger = init_logger(__name__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment