Unverified Commit 32b14baf authored by Ce Gao's avatar Ce Gao Committed by GitHub
Browse files

[Refactor][Frontend] Keep all logic about reasoning into one class (#14428)


Signed-off-by: default avatarCe Gao <cegao@tensorchord.ai>
parent 2d9045fc
...@@ -3,74 +3,92 @@ ...@@ -3,74 +3,92 @@
import pytest import pytest
from transformers import AutoTokenizer from transformers import AutoTokenizer
from tests.entrypoints.openai.reasoning_parsers.utils import ( from tests.reasoning.utils import run_reasoning_extraction
run_reasoning_extraction) from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
ReasoningParserManager)
parser_name = "deepseek_r1" parser_name = "deepseek_r1"
start_token = "<think>" start_token = "<think>"
end_token = "</think>" end_token = "</think>"
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
@pytest.fixture(scope="module")
def deepseek_r1_qwen_tokenizer():
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
SIMPLE_REASONING = { SIMPLE_REASONING = {
"output": "This is a reasoning section</think>This is the rest", "output": "This is a reasoning section</think>This is the rest",
"reasoning_content": "This is a reasoning section", "reasoning_content": "This is a reasoning section",
"content": "This is the rest", "content": "This is the rest",
"is_reasoning_end": True,
} }
COMPLETE_REASONING = { COMPLETE_REASONING = {
"output": "This is a reasoning section</think>", "output": "This is a reasoning section</think>",
"reasoning_content": "This is a reasoning section", "reasoning_content": "This is a reasoning section",
"content": None, "content": None,
"is_reasoning_end": True,
} }
NO_CONTENT = { NO_CONTENT = {
"output": "This is content", "output": "This is content",
"reasoning_content": "This is content", "reasoning_content": "This is content",
"content": None, "content": None,
"is_reasoning_end": False,
} }
NO_REASONING_STREAMING = { NO_REASONING_STREAMING = {
"output": "This is a reasoning section", "output": "This is a reasoning section",
"reasoning_content": "This is a reasoning section", "reasoning_content": "This is a reasoning section",
"content": None, "content": None,
"is_reasoning_end": False,
} }
MULTIPLE_LINES = { MULTIPLE_LINES = {
"output": "This\nThat</think>This is the rest\nThat", "output": "This\nThat</think>This is the rest\nThat",
"reasoning_content": "This\nThat", "reasoning_content": "This\nThat",
"content": "This is the rest\nThat", "content": "This is the rest\nThat",
"is_reasoning_end": True,
} }
SHORTEST_REASONING_NO_STREAMING = { SHORTEST_REASONING_NO_STREAMING = {
"output": "</think>This is the rest", "output": "</think>This is the rest",
"reasoning_content": "", "reasoning_content": "",
"content": "This is the rest", "content": "This is the rest",
"is_reasoning_end": True,
} }
SHORTEST_REASONING = { SHORTEST_REASONING = {
"output": "</think>This is the rest", "output": "</think>This is the rest",
"reasoning_content": None, "reasoning_content": None,
"content": "This is the rest", "content": "This is the rest",
"is_reasoning_end": True,
} }
REASONING_WITH_THINK = { REASONING_WITH_THINK = {
"output": "<think>This is a reasoning section</think>This is the rest", "output": "<think>This is a reasoning section</think>This is the rest",
"reasoning_content": "This is a reasoning section", "reasoning_content": "This is a reasoning section",
"content": "This is the rest", "content": "This is the rest",
"is_reasoning_end": True,
} }
COMPLETE_REASONING_WITH_THINK = { COMPLETE_REASONING_WITH_THINK = {
"output": "<think>This is a reasoning section</think>", "output": "<think>This is a reasoning section</think>",
"reasoning_content": "This is a reasoning section", "reasoning_content": "This is a reasoning section",
"content": None, "content": None,
"is_reasoning_end": True,
} }
MULTIPLE_LINES_WITH_THINK = { MULTIPLE_LINES_WITH_THINK = {
"output": "<think>This\nThat</think>This is the rest\nThat", "output": "<think>This\nThat</think>This is the rest\nThat",
"reasoning_content": "This\nThat", "reasoning_content": "This\nThat",
"content": "This is the rest\nThat", "content": "This is the rest\nThat",
"is_reasoning_end": True,
} }
SHORTEST_REASONING_NO_STREAMING_WITH_THINK = { SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
"output": "</think>This is the rest", "output": "</think>This is the rest",
"reasoning_content": "", "reasoning_content": "",
"content": "This is the rest", "content": "This is the rest",
"is_reasoning_end": True,
} }
SHORTEST_REASONING_WITH_THINK = { SHORTEST_REASONING_WITH_THINK = {
"output": "</think>This is the rest", "output": "</think>This is the rest",
"reasoning_content": None, "reasoning_content": None,
"content": "This is the rest", "content": "This is the rest",
"is_reasoning_end": True,
} }
TEST_CASES = [ TEST_CASES = [
...@@ -166,23 +184,21 @@ TEST_CASES = [ ...@@ -166,23 +184,21 @@ TEST_CASES = [
), ),
] ]
# Global tokenizer initialization to avoid repeated loading
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
tokenizer.add_tokens([start_token, end_token])
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES) @pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning( def test_reasoning(
streaming: bool, streaming: bool,
param_dict: dict, param_dict: dict,
deepseek_r1_qwen_tokenizer,
): ):
output = tokenizer.tokenize(param_dict["output"]) output = deepseek_r1_qwen_tokenizer.tokenize(param_dict["output"])
# decode everything to tokens # decode everything to tokens
output_tokens: list[str] = [ output_tokens: list[str] = [
tokenizer.convert_tokens_to_string([token]) for token in output deepseek_r1_qwen_tokenizer.convert_tokens_to_string([token])
for token in output
] ]
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
parser_name)(tokenizer) parser_name)(deepseek_r1_qwen_tokenizer)
reasoning, content = run_reasoning_extraction(parser, reasoning, content = run_reasoning_extraction(parser,
output_tokens, output_tokens,
...@@ -190,3 +206,17 @@ def test_reasoning( ...@@ -190,3 +206,17 @@ def test_reasoning(
assert reasoning == param_dict["reasoning_content"] assert reasoning == param_dict["reasoning_content"]
assert content == param_dict["content"] assert content == param_dict["content"]
# Test is_reasoning_end
output_ids = deepseek_r1_qwen_tokenizer.convert_tokens_to_ids(output)
is_reasoning_end = parser.is_reasoning_end(output_ids)
assert is_reasoning_end == param_dict["is_reasoning_end"]
# Test extract_content
if param_dict["content"] is not None:
content = parser.extract_content_ids(output_ids)
assert content == deepseek_r1_qwen_tokenizer.convert_tokens_to_ids(
deepseek_r1_qwen_tokenizer.tokenize(param_dict["content"]))
else:
content = parser.extract_content_ids(output)
assert content == []
...@@ -2,10 +2,8 @@ ...@@ -2,10 +2,8 @@
import pytest import pytest
from transformers import AutoTokenizer from transformers import AutoTokenizer
from tests.entrypoints.openai.reasoning_parsers.utils import ( from tests.reasoning.utils import DeltaMessage, run_reasoning_extraction
DeltaMessage, run_reasoning_extraction) from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
ReasoningParserManager)
parser_name = "granite" parser_name = "granite"
START_REASONING = "Here is my thought process:" START_REASONING = "Here is my thought process:"
......
...@@ -4,7 +4,7 @@ from typing import Optional, Union ...@@ -4,7 +4,7 @@ from typing import Optional, Union
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage) DeltaMessage)
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParser from vllm.reasoning import ReasoningParser
class StreamingReasoningReconstructor: class StreamingReasoningReconstructor:
......
...@@ -23,6 +23,7 @@ from vllm.executor.executor_base import ExecutorBase ...@@ -23,6 +23,7 @@ from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.plugins import load_general_plugins from vllm.plugins import load_general_plugins
from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
...@@ -1119,7 +1120,7 @@ class EngineArgs: ...@@ -1119,7 +1120,7 @@ class EngineArgs:
parser.add_argument( parser.add_argument(
"--reasoning-parser", "--reasoning-parser",
type=str, type=str,
choices=["deepseek_r1", "granite"], choices=list(ReasoningParserManager.reasoning_parsers),
default=None, default=None,
help= help=
"Select the reasoning parser depending on the model that you're " "Select the reasoning parser depending on the model that you're "
......
...@@ -2080,8 +2080,9 @@ class LLMEngine: ...@@ -2080,8 +2080,9 @@ class LLMEngine:
guided_decoding.backend = guided_decoding.backend or \ guided_decoding.backend = guided_decoding.backend or \
self.decoding_config.guided_decoding_backend self.decoding_config.guided_decoding_backend
logger.debug("Reasoning backend: %s", if self.decoding_config.reasoning_backend is not None:
self.decoding_config.reasoning_backend) logger.debug("Building with reasoning backend %s",
self.decoding_config.reasoning_backend)
processor = get_local_guided_decoding_logits_processor( processor = get_local_guided_decoding_logits_processor(
guided_params=guided_decoding, guided_params=guided_decoding,
......
...@@ -68,7 +68,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ...@@ -68,7 +68,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TranscriptionRequest, TranscriptionRequest,
TranscriptionResponse, TranscriptionResponse,
UnloadLoRAAdapterRequest) UnloadLoRAAdapterRequest)
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
# yapf: enable # yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
...@@ -85,6 +84,7 @@ from vllm.entrypoints.openai.serving_transcription import ( ...@@ -85,6 +84,7 @@ from vllm.entrypoints.openai.serving_transcription import (
from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import load_aware_call, with_cancellation from vllm.entrypoints.utils import load_aware_call, with_cancellation
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value) maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.transformers_utils.tokenizer import MistralTokenizer
......
...@@ -23,8 +23,6 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -23,8 +23,6 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo, DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
RequestResponseMetadata, ToolCall, UsageInfo) RequestResponseMetadata, ToolCall, UsageInfo)
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
ReasoningParserManager)
from vllm.entrypoints.openai.serving_engine import (OpenAIServing, from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
clamp_prompt_logprobs) clamp_prompt_logprobs)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
...@@ -33,6 +31,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( ...@@ -33,6 +31,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
MistralToolCall) MistralToolCall)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
......
...@@ -5,10 +5,10 @@ from __future__ import annotations ...@@ -5,10 +5,10 @@ from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.reasoner import get_reasoner
from vllm.model_executor.guided_decoding.utils import ( from vllm.model_executor.guided_decoding.utils import (
convert_lark_to_gbnf, grammar_is_likely_lark, convert_lark_to_gbnf, grammar_is_likely_lark,
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features) has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
from vllm.reasoning import ReasoningParserManager
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
...@@ -107,7 +107,11 @@ async def get_guided_decoding_logits_processor( ...@@ -107,7 +107,11 @@ async def get_guided_decoding_logits_processor(
model_config: ModelConfig, model_config: ModelConfig,
reasoning_backend: str | None = None) -> LogitsProcessor | None: reasoning_backend: str | None = None) -> LogitsProcessor | None:
reasoner = get_reasoner(tokenizer, reasoning_backend) reasoner = None
if reasoning_backend is not None:
reasoner_class = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
reasoner = reasoner_class(tokenizer)
guided_params = maybe_backend_fallback(guided_params) guided_params = maybe_backend_fallback(guided_params)
...@@ -146,8 +150,11 @@ def get_local_guided_decoding_logits_processor( ...@@ -146,8 +150,11 @@ def get_local_guided_decoding_logits_processor(
reasoning_backend: str | None = None) -> LogitsProcessor | None: reasoning_backend: str | None = None) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params) guided_params = maybe_backend_fallback(guided_params)
# Get the reasoner if needed, it will be None if reasoning_ reasoner = None
reasoner = get_reasoner(tokenizer, reasoning_backend) if reasoning_backend is not None:
reasoner_class = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
reasoner = reasoner_class(tokenizer)
# CFG grammar not supported by LMFE, so we use outlines instead # CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend_name == 'outlines': if guided_params.backend_name == 'outlines':
......
...@@ -12,7 +12,7 @@ from transformers import PreTrainedTokenizerBase ...@@ -12,7 +12,7 @@ from transformers import PreTrainedTokenizerBase
from vllm.model_executor.guided_decoding.outlines_logits_processors import ( from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.model_executor.guided_decoding.reasoner import Reasoner from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams from vllm.sampling_params import GuidedDecodingParams
...@@ -61,7 +61,7 @@ _MAX_THREADPOOL_WORKERS = 16 ...@@ -61,7 +61,7 @@ _MAX_THREADPOOL_WORKERS = 16
async def get_outlines_guided_decoding_logits_processor( async def get_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner], reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]: None]:
""" """
...@@ -92,7 +92,7 @@ async def get_outlines_guided_decoding_logits_processor( ...@@ -92,7 +92,7 @@ async def get_outlines_guided_decoding_logits_processor(
def get_local_outlines_guided_decoding_logits_processor( def get_local_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner], reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]: None]:
""" """
...@@ -141,7 +141,7 @@ def _get_logits_processor( ...@@ -141,7 +141,7 @@ def _get_logits_processor(
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
mode: GuidedDecodingMode, mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None], whitespace_pattern: Union[str, None],
reasoner: Optional[Reasoner], reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]: ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
if mode == GuidedDecodingMode.JSON: if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern, return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,
......
...@@ -34,8 +34,8 @@ from transformers import PreTrainedTokenizerBase ...@@ -34,8 +34,8 @@ from transformers import PreTrainedTokenizerBase
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.reasoning import ReasoningParser
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -49,9 +49,9 @@ else: ...@@ -49,9 +49,9 @@ else:
class BaseLogitsProcessor: class BaseLogitsProcessor:
def __init__(self, guide: Guide, reasoner: Optional[Reasoner]): def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]):
self._guide: Guide = guide self._guide: Guide = guide
self._reasoner: Optional[Reasoner] = reasoner self._reasoner: Optional[ReasoningParser] = reasoner
# CFGState is used for the FSM state for CFGGuide # CFGState is used for the FSM state for CFGGuide
self._fsm_state: DefaultDict[int, Union[int, self._fsm_state: DefaultDict[int, Union[int,
CFGState]] = defaultdict(int) CFGState]] = defaultdict(int)
...@@ -69,7 +69,7 @@ class BaseLogitsProcessor: ...@@ -69,7 +69,7 @@ class BaseLogitsProcessor:
# Remove the reasoning tokens from the input_ids # Remove the reasoning tokens from the input_ids
# We need this because our implementation relies on the # We need this because our implementation relies on the
# hash of the input_ids to store the FSM state. # hash of the input_ids to store the FSM state.
input_ids = self._reasoner.extract_content(input_ids) input_ids = self._reasoner.extract_content_ids(input_ids)
seq_id = hash(tuple(input_ids)) seq_id = hash(tuple(input_ids))
...@@ -142,7 +142,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor): ...@@ -142,7 +142,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
self, self,
regex_string: str, regex_string: str,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner], reasoner: Optional[ReasoningParser],
): ):
"""Compile the FSM that drives the regex-structured generation. """Compile the FSM that drives the regex-structured generation.
...@@ -163,7 +163,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor): ...@@ -163,7 +163,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self, schema: Union[str, Dict, BaseModel], def __init__(self, schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Union[str, None], whitespace_pattern: Union[str, None],
reasoner: Optional[Reasoner]): reasoner: Optional[ReasoningParser]):
"""Compile the FSM that drives the JSON-guided generation. """Compile the FSM that drives the JSON-guided generation.
Parameters Parameters
...@@ -203,7 +203,7 @@ class CFGLogitsProcessor(BaseLogitsProcessor): ...@@ -203,7 +203,7 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
return CFGGuide(cfg, tokenizer) return CFGGuide(cfg, tokenizer)
def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase, def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner]): reasoner: Optional[ReasoningParser]):
"""Compile the FSM that drives the context free grammar generation. """Compile the FSM that drives the context free grammar generation.
Parameters Parameters
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner
@dataclass
class DeepSeekReasoner(Reasoner):
"""
Reasoner for DeepSeek R series models.
"""
start_token_id: int
end_token_id: int
start_token: str = "<think>"
end_token: str = "</think>"
@classmethod
def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
return cls(start_token_id=tokenizer.encode(
"<think>", add_special_tokens=False)[0],
end_token_id=tokenizer.encode("</think>",
add_special_tokens=False)[0])
def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.end_token_id in input_ids
def extract_content(self, input_ids: list[int]) -> list[int]:
"""
Extract the content after the end tokens
"""
if self.end_token_id not in input_ids or \
input_ids.index(self.end_token_id) + 1 == len(input_ids):
return []
else:
return input_ids[input_ids.index(self.end_token_id) + 1:]
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
@dataclass
class Reasoner(ABC):
@abstractmethod
def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
pass
@abstractmethod
def is_reasoning_end(self, input_ids: list[int]) -> bool:
pass
@abstractmethod
def extract_content(self, input_ids: list[int]) -> list[int]:
pass
...@@ -27,7 +27,7 @@ if TYPE_CHECKING: ...@@ -27,7 +27,7 @@ if TYPE_CHECKING:
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.model_executor.guided_decoding.reasoner import Reasoner from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams from vllm.sampling_params import GuidedDecodingParams
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -37,7 +37,7 @@ def get_local_xgrammar_guided_decoding_logits_processor( ...@@ -37,7 +37,7 @@ def get_local_xgrammar_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
model_config: ModelConfig, model_config: ModelConfig,
reasoner: Reasoner | None, reasoner: ReasoningParser | None,
max_threads: int = 8): max_threads: int = 8):
config = GrammarConfig.from_guided_params(guided_params=guided_params, config = GrammarConfig.from_guided_params(guided_params=guided_params,
model_config=model_config, model_config=model_config,
...@@ -280,7 +280,7 @@ class GrammarConfig: ...@@ -280,7 +280,7 @@ class GrammarConfig:
class XGrammarLogitsProcessor: class XGrammarLogitsProcessor:
"""Wrapper class to support pickle protocol""" """Wrapper class to support pickle protocol"""
config: GrammarConfig config: GrammarConfig
reasoner: Reasoner | None = None reasoner: ReasoningParser | None = None
ctx: xgr.CompiledGrammar | None = None ctx: xgr.CompiledGrammar | None = None
tokenizer_info: xgr.TokenizerInfo = None # type: ignore[assignment] tokenizer_info: xgr.TokenizerInfo = None # type: ignore[assignment]
......
...@@ -17,7 +17,7 @@ logger = init_logger(__name__) ...@@ -17,7 +17,7 @@ logger = init_logger(__name__)
class ReasoningParser: class ReasoningParser:
""" """
Abstract reasoning parser class that should not be used directly. Abstract reasoning parser class that should not be used directly.
Provided and methods should be used in derived classes. Provided and methods should be used in derived classes.
It is used to extract reasoning content from the model output. It is used to extract reasoning content from the model output.
...@@ -32,6 +32,36 @@ class ReasoningParser: ...@@ -32,6 +32,36 @@ class ReasoningParser:
# whereas all tokenizers have .get_vocab() # whereas all tokenizers have .get_vocab()
return self.model_tokenizer.get_vocab() return self.model_tokenizer.get_vocab()
@abstractmethod
def is_reasoning_end(self, input_ids: list[int]) -> bool:
"""
Check if the reasoning content ends in the input_ids.
It is used in structured engines like `xgrammar` to check if the
reasoning content ends in the model output.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
bool
True if the reasoning content ends in the input_ids.
"""
@abstractmethod
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract content token ids from the input_ids.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
list[int]
The extracted content from the input_ids.
"""
@abstractmethod
def extract_reasoning_content( def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest self, model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]: ) -> tuple[Optional[str], Optional[str]]:
...@@ -53,10 +83,7 @@ class ReasoningParser: ...@@ -53,10 +83,7 @@ class ReasoningParser:
A tuple containing the reasoning content and the content. A tuple containing the reasoning content and the content.
""" """
raise NotImplementedError( @abstractmethod
"AbstractReasoningParser.extract_reasoning_calls "
"has not been implemented!")
def extract_reasoning_content_streaming( def extract_reasoning_content_streaming(
self, self,
previous_text: str, previous_text: str,
...@@ -73,43 +100,6 @@ class ReasoningParser: ...@@ -73,43 +100,6 @@ class ReasoningParser:
the current tokens/diffs, but also the information about what has the current tokens/diffs, but also the information about what has
previously been parsed and extracted (see constructor) previously been parsed and extracted (see constructor)
""" """
raise NotImplementedError(
"AbstractReasoningParser.extract_reasoning_content_streaming "
"has not been implemented!")
# TODO: need to rebase by PR #14428
@abstractmethod
def is_reasoning_end(self, input_ids: list[int]) -> bool:
"""
Check if the reasoning content ends in the input_ids.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
bool
True if the reasoning content ends in the input_ids.
"""
raise NotImplementedError(
"AbstractReasoningParser.is_reasoning_end has"
"not been implemented!")
# TODO: need to rebase by PR #14428
@abstractmethod
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract content token ids from the input_ids.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
list[int]
The extracted content from the input_ids.
"""
raise NotImplementedError(
"AbstractReasoningParser.extract_content_ids has"
" not been implemented!")
class ReasoningParserManager: class ReasoningParserManager:
...@@ -125,14 +115,16 @@ class ReasoningParserManager: ...@@ -125,14 +115,16 @@ class ReasoningParserManager:
if name in cls.reasoning_parsers: if name in cls.reasoning_parsers:
return cls.reasoning_parsers[name] return cls.reasoning_parsers[name]
raise KeyError(f"reasoning helper: '{name}' not found in " raise KeyError(
"reasoning_parsers") f"reasoning helper: '{name}' not found in reasoning_parsers")
@classmethod @classmethod
def _register_module(cls, def _register_module(
module: type, cls,
module_name: Optional[Union[str, list[str]]] = None, module: type,
force: bool = True) -> None: module_name: Optional[Union[str, list[str]]] = None,
force: bool = True,
) -> None:
if not issubclass(module, ReasoningParser): if not issubclass(module, ReasoningParser):
raise TypeError("module must be subclass of ReasoningParser, " raise TypeError("module must be subclass of ReasoningParser, "
f"but got {type(module)}") f"but got {type(module)}")
...@@ -149,13 +141,14 @@ class ReasoningParserManager: ...@@ -149,13 +141,14 @@ class ReasoningParserManager:
@classmethod @classmethod
def register_module( def register_module(
cls, cls,
name: Optional[Union[str, list[str]]] = None, name: Optional[Union[str, list[str]]] = None,
force: bool = True, force: bool = True,
module: Union[type, None] = None) -> Union[type, Callable]: module: Union[type, None] = None,
) -> Union[type, Callable]:
""" """
Register module with the given name or name list. it can be used as a Register module with the given name or name list. it can be used as a
decoder(with module as None) or normal function(with module as not decoder(with module as None) or normal function(with module as not
None). None).
""" """
if not isinstance(force, bool): if not isinstance(force, bool):
...@@ -183,7 +176,7 @@ class ReasoningParserManager: ...@@ -183,7 +176,7 @@ class ReasoningParserManager:
@classmethod @classmethod
def import_reasoning_parser(cls, plugin_path: str) -> None: def import_reasoning_parser(cls, plugin_path: str) -> None:
""" """
Import a user-defined reasoning parser by the path Import a user-defined reasoning parser by the path
of the reasoning parser define file. of the reasoning parser define file.
""" """
module_name = os.path.splitext(os.path.basename(plugin_path))[0] module_name = os.path.splitext(os.path.basename(plugin_path))[0]
......
...@@ -8,9 +8,8 @@ from transformers import PreTrainedTokenizerBase ...@@ -8,9 +8,8 @@ from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage) DeltaMessage)
from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import (
ReasoningParser, ReasoningParserManager)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -20,43 +19,45 @@ class DeepSeekR1ReasoningParser(ReasoningParser): ...@@ -20,43 +19,45 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
""" """
Reasoning parser for DeepSeek R1 model. Reasoning parser for DeepSeek R1 model.
The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning
text. This parser extracts the reasoning content from the model output. text. This parser extracts the reasoning content from the model output.
""" """
start_token_id: int
end_token_id: int
start_token: str = "<think>"
end_token: str = "</think>"
def __init__(self, tokenizer: PreTrainedTokenizerBase): def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer) super().__init__(tokenizer)
self.think_start_token = "<think>"
self.think_end_token = "</think>"
self.reasoning_regex = re.compile( self.reasoning_regex = re.compile(
rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL) rf"{self.start_token}(.*?){self.end_token}", re.DOTALL)
if not self.model_tokenizer: if not self.model_tokenizer:
raise ValueError( raise ValueError(
"The model tokenizer must be passed to the ReasoningParser " "The model tokenizer must be passed to the ReasoningParser "
"constructor during construction.") "constructor during construction.")
self.think_start_token_id = self.vocab.get(self.think_start_token) self.start_token_id = self.vocab.get(self.start_token)
self.think_end_token_id = self.vocab.get(self.think_end_token) self.end_token_id = self.vocab.get(self.end_token)
if (self.think_start_token_id is None if self.start_token_id is None or self.end_token_id is None:
or self.think_end_token_id is None):
raise RuntimeError( raise RuntimeError(
"DeepSeek R1 reasoning parser could not locate think start/end " "DeepSeek R1 reasoning parser could not locate think start/end "
"tokens in the tokenizer!") "tokens in the tokenizer!")
# TODO: need to rebase by PR #14428
def is_reasoning_end(self, input_ids: list[int]) -> bool: def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.think_end_token_id in input_ids return self.end_token_id in input_ids
def extract_content_ids(self, input_ids: list[int]) -> list[int]: def extract_content_ids(self, input_ids: list[int]) -> list[int]:
""" """
Extract the content after the end tokens Extract the content after the end tokens
""" """
if self.think_end_token_id not in input_ids[:-1]: if self.end_token_id not in input_ids[:-1]:
return [] return []
else: else:
return input_ids[input_ids.index(self.think_end_token_id) + 1:] return input_ids[input_ids.index(self.end_token_id) + 1:]
def extract_reasoning_content_streaming( def extract_reasoning_content_streaming(
self, self,
...@@ -77,22 +78,24 @@ class DeepSeekR1ReasoningParser(ReasoningParser): ...@@ -77,22 +78,24 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
""" """
# Skip single special tokens # Skip single special tokens
if len(delta_token_ids) == 1 and (delta_token_ids[0] in [ if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
self.think_start_token_id, self.think_end_token_id self.start_token_id, self.end_token_id
]): ]):
return None return None
# Check if <think> is present in previous or delta. # Check if <think> is present in previous or delta.
# Keep compatibility with models that don't generate <think> tokens. # Keep compatibility with models that don't generate <think> tokens.
if self.think_start_token_id in previous_token_ids: if self.start_token_id in previous_token_ids:
if self.think_end_token_id in delta_token_ids: if self.end_token_id in delta_token_ids:
# <think> in previous, </think> in delta, # <think> in previous, </think> in delta,
# extract reasoning content # extract reasoning content
end_index = delta_text.find(self.think_end_token) end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[:end_index] reasoning_content = delta_text[:end_index]
content = delta_text[end_index + len(self.think_end_token):] content = delta_text[end_index + len(self.end_token):]
return DeltaMessage(reasoning_content=reasoning_content, return DeltaMessage(
content=content if content else None) reasoning_content=reasoning_content,
elif self.think_end_token_id in previous_token_ids: content=content if content else None,
)
elif self.end_token_id in previous_token_ids:
# <think> in previous, </think> in previous, # <think> in previous, </think> in previous,
# reasoning content continues # reasoning content continues
return DeltaMessage(content=delta_text) return DeltaMessage(content=delta_text)
...@@ -100,17 +103,18 @@ class DeepSeekR1ReasoningParser(ReasoningParser): ...@@ -100,17 +103,18 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# <think> in previous, no </think> in previous or delta, # <think> in previous, no </think> in previous or delta,
# reasoning content continues # reasoning content continues
return DeltaMessage(reasoning_content=delta_text) return DeltaMessage(reasoning_content=delta_text)
elif self.think_start_token_id in delta_token_ids: elif self.start_token_id in delta_token_ids:
if self.think_end_token_id in delta_token_ids: if self.end_token_id in delta_token_ids:
# <think> in delta, </think> in delta, extract reasoning content # <think> in delta, </think> in delta, extract reasoning content
start_index = delta_text.find(self.think_start_token) start_index = delta_text.find(self.start_token)
end_index = delta_text.find(self.think_end_token) end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[start_index + reasoning_content = delta_text[start_index +
len(self.think_start_token len(self.start_token):end_index]
):end_index] content = delta_text[end_index + len(self.end_token):]
content = delta_text[end_index + len(self.think_end_token):] return DeltaMessage(
return DeltaMessage(reasoning_content=reasoning_content, reasoning_content=reasoning_content,
content=content if content else None) content=content if content else None,
)
else: else:
# <think> in delta, no </think> in delta, # <think> in delta, no </think> in delta,
# reasoning content continues # reasoning content continues
...@@ -119,15 +123,17 @@ class DeepSeekR1ReasoningParser(ReasoningParser): ...@@ -119,15 +123,17 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# No <think> in previous or delta, also need to check for </think>. # No <think> in previous or delta, also need to check for </think>.
# Because the model may have generated </think> without <think> # Because the model may have generated </think> without <think>
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if self.think_end_token_id in delta_token_ids: if self.end_token_id in delta_token_ids:
# </think> in delta with more tokens, # </think> in delta with more tokens,
# extract reasoning content and content # extract reasoning content and content
end_index = delta_text.find(self.think_end_token) end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[:end_index] reasoning_content = delta_text[:end_index]
content = delta_text[end_index + len(self.think_end_token):] content = delta_text[end_index + len(self.end_token):]
return DeltaMessage(reasoning_content=reasoning_content, return DeltaMessage(
content=content if content else None) reasoning_content=reasoning_content,
elif self.think_end_token_id in previous_token_ids: content=content if content else None,
)
elif self.end_token_id in previous_token_ids:
# </think> in previous, thinking content ends # </think> in previous, thinking content ends
return DeltaMessage(content=delta_text) return DeltaMessage(content=delta_text)
else: else:
...@@ -137,22 +143,20 @@ class DeepSeekR1ReasoningParser(ReasoningParser): ...@@ -137,22 +143,20 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
def extract_reasoning_content( def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest self, model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]: ) -> tuple[Optional[str], Optional[str]]:
# DeepSeek R1 doesn't generate <think> now. # DeepSeek R1 doesn't generate <think> now.
# Thus we assume the reasoning content is always at the start. # Thus we assume the reasoning content is always at the start.
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if self.think_end_token not in model_output: if self.end_token not in model_output:
return model_output, None return model_output, None
else: else:
# Add a start token if it's missing to keep compatibility. # Add a start token if it's missing to keep compatibility.
if self.think_start_token not in model_output: if self.start_token not in model_output:
model_output = f"{self.think_start_token}{model_output}" model_output = f"{self.start_token}{model_output}"
# Use a regex to find the reasoning content # Use a regex to find the reasoning content
reasoning_content = self.reasoning_regex.findall(model_output)[0] reasoning_content = self.reasoning_regex.findall(model_output)[0]
end_index = len( end_index = len(
f"{self.think_start_token}{reasoning_content}{self.think_end_token}" f"{self.start_token}{reasoning_content}{self.end_token}")
)
final_output = model_output[end_index:] final_output = model_output[end_index:]
if len(final_output) == 0: if len(final_output) == 0:
......
...@@ -8,9 +8,8 @@ from transformers import PreTrainedTokenizerBase ...@@ -8,9 +8,8 @@ from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage) DeltaMessage)
from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import (
ReasoningParser, ReasoningParserManager)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager
logger = init_logger(__name__) logger = init_logger(__name__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment