Unverified Commit 32b14baf authored by Ce Gao's avatar Ce Gao Committed by GitHub
Browse files

[Refactor][Frontend] Keep all logic about reasoning into one class (#14428)


Signed-off-by: default avatarCe Gao <cegao@tensorchord.ai>
parent 2d9045fc
...@@ -3,74 +3,92 @@ ...@@ -3,74 +3,92 @@
import pytest import pytest
from transformers import AutoTokenizer from transformers import AutoTokenizer
from tests.entrypoints.openai.reasoning_parsers.utils import ( from tests.reasoning.utils import run_reasoning_extraction
run_reasoning_extraction) from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
ReasoningParserManager)
parser_name = "deepseek_r1" parser_name = "deepseek_r1"
start_token = "<think>" start_token = "<think>"
end_token = "</think>" end_token = "</think>"
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
@pytest.fixture(scope="module")
def deepseek_r1_qwen_tokenizer():
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
SIMPLE_REASONING = { SIMPLE_REASONING = {
"output": "This is a reasoning section</think>This is the rest", "output": "This is a reasoning section</think>This is the rest",
"reasoning_content": "This is a reasoning section", "reasoning_content": "This is a reasoning section",
"content": "This is the rest", "content": "This is the rest",
"is_reasoning_end": True,
} }
COMPLETE_REASONING = { COMPLETE_REASONING = {
"output": "This is a reasoning section</think>", "output": "This is a reasoning section</think>",
"reasoning_content": "This is a reasoning section", "reasoning_content": "This is a reasoning section",
"content": None, "content": None,
"is_reasoning_end": True,
} }
NO_CONTENT = { NO_CONTENT = {
"output": "This is content", "output": "This is content",
"reasoning_content": "This is content", "reasoning_content": "This is content",
"content": None, "content": None,
"is_reasoning_end": False,
} }
NO_REASONING_STREAMING = { NO_REASONING_STREAMING = {
"output": "This is a reasoning section", "output": "This is a reasoning section",
"reasoning_content": "This is a reasoning section", "reasoning_content": "This is a reasoning section",
"content": None, "content": None,
"is_reasoning_end": False,
} }
MULTIPLE_LINES = { MULTIPLE_LINES = {
"output": "This\nThat</think>This is the rest\nThat", "output": "This\nThat</think>This is the rest\nThat",
"reasoning_content": "This\nThat", "reasoning_content": "This\nThat",
"content": "This is the rest\nThat", "content": "This is the rest\nThat",
"is_reasoning_end": True,
} }
SHORTEST_REASONING_NO_STREAMING = { SHORTEST_REASONING_NO_STREAMING = {
"output": "</think>This is the rest", "output": "</think>This is the rest",
"reasoning_content": "", "reasoning_content": "",
"content": "This is the rest", "content": "This is the rest",
"is_reasoning_end": True,
} }
SHORTEST_REASONING = { SHORTEST_REASONING = {
"output": "</think>This is the rest", "output": "</think>This is the rest",
"reasoning_content": None, "reasoning_content": None,
"content": "This is the rest", "content": "This is the rest",
"is_reasoning_end": True,
} }
REASONING_WITH_THINK = { REASONING_WITH_THINK = {
"output": "<think>This is a reasoning section</think>This is the rest", "output": "<think>This is a reasoning section</think>This is the rest",
"reasoning_content": "This is a reasoning section", "reasoning_content": "This is a reasoning section",
"content": "This is the rest", "content": "This is the rest",
"is_reasoning_end": True,
} }
COMPLETE_REASONING_WITH_THINK = { COMPLETE_REASONING_WITH_THINK = {
"output": "<think>This is a reasoning section</think>", "output": "<think>This is a reasoning section</think>",
"reasoning_content": "This is a reasoning section", "reasoning_content": "This is a reasoning section",
"content": None, "content": None,
"is_reasoning_end": True,
} }
MULTIPLE_LINES_WITH_THINK = { MULTIPLE_LINES_WITH_THINK = {
"output": "<think>This\nThat</think>This is the rest\nThat", "output": "<think>This\nThat</think>This is the rest\nThat",
"reasoning_content": "This\nThat", "reasoning_content": "This\nThat",
"content": "This is the rest\nThat", "content": "This is the rest\nThat",
"is_reasoning_end": True,
} }
SHORTEST_REASONING_NO_STREAMING_WITH_THINK = { SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
"output": "</think>This is the rest", "output": "</think>This is the rest",
"reasoning_content": "", "reasoning_content": "",
"content": "This is the rest", "content": "This is the rest",
"is_reasoning_end": True,
} }
SHORTEST_REASONING_WITH_THINK = { SHORTEST_REASONING_WITH_THINK = {
"output": "</think>This is the rest", "output": "</think>This is the rest",
"reasoning_content": None, "reasoning_content": None,
"content": "This is the rest", "content": "This is the rest",
"is_reasoning_end": True,
} }
TEST_CASES = [ TEST_CASES = [
...@@ -166,23 +184,21 @@ TEST_CASES = [ ...@@ -166,23 +184,21 @@ TEST_CASES = [
), ),
] ]
# Global tokenizer initialization to avoid repeated loading
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
tokenizer.add_tokens([start_token, end_token])
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES) @pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning( def test_reasoning(
streaming: bool, streaming: bool,
param_dict: dict, param_dict: dict,
deepseek_r1_qwen_tokenizer,
): ):
output = tokenizer.tokenize(param_dict["output"]) output = deepseek_r1_qwen_tokenizer.tokenize(param_dict["output"])
# decode everything to tokens # decode everything to tokens
output_tokens: list[str] = [ output_tokens: list[str] = [
tokenizer.convert_tokens_to_string([token]) for token in output deepseek_r1_qwen_tokenizer.convert_tokens_to_string([token])
for token in output
] ]
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
parser_name)(tokenizer) parser_name)(deepseek_r1_qwen_tokenizer)
reasoning, content = run_reasoning_extraction(parser, reasoning, content = run_reasoning_extraction(parser,
output_tokens, output_tokens,
...@@ -190,3 +206,17 @@ def test_reasoning( ...@@ -190,3 +206,17 @@ def test_reasoning(
assert reasoning == param_dict["reasoning_content"] assert reasoning == param_dict["reasoning_content"]
assert content == param_dict["content"] assert content == param_dict["content"]
# Test is_reasoning_end
output_ids = deepseek_r1_qwen_tokenizer.convert_tokens_to_ids(output)
is_reasoning_end = parser.is_reasoning_end(output_ids)
assert is_reasoning_end == param_dict["is_reasoning_end"]
# Test extract_content
if param_dict["content"] is not None:
content = parser.extract_content_ids(output_ids)
assert content == deepseek_r1_qwen_tokenizer.convert_tokens_to_ids(
deepseek_r1_qwen_tokenizer.tokenize(param_dict["content"]))
else:
content = parser.extract_content_ids(output)
assert content == []
...@@ -2,10 +2,8 @@ ...@@ -2,10 +2,8 @@
import pytest import pytest
from transformers import AutoTokenizer from transformers import AutoTokenizer
from tests.entrypoints.openai.reasoning_parsers.utils import ( from tests.reasoning.utils import DeltaMessage, run_reasoning_extraction
DeltaMessage, run_reasoning_extraction) from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
ReasoningParserManager)
parser_name = "granite" parser_name = "granite"
START_REASONING = "Here is my thought process:" START_REASONING = "Here is my thought process:"
......
...@@ -4,7 +4,7 @@ from typing import Optional, Union ...@@ -4,7 +4,7 @@ from typing import Optional, Union
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage) DeltaMessage)
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParser from vllm.reasoning import ReasoningParser
class StreamingReasoningReconstructor: class StreamingReasoningReconstructor:
......
...@@ -23,6 +23,7 @@ from vllm.executor.executor_base import ExecutorBase ...@@ -23,6 +23,7 @@ from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.plugins import load_general_plugins from vllm.plugins import load_general_plugins
from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
...@@ -1119,7 +1120,7 @@ class EngineArgs: ...@@ -1119,7 +1120,7 @@ class EngineArgs:
parser.add_argument( parser.add_argument(
"--reasoning-parser", "--reasoning-parser",
type=str, type=str,
choices=["deepseek_r1", "granite"], choices=list(ReasoningParserManager.reasoning_parsers),
default=None, default=None,
help= help=
"Select the reasoning parser depending on the model that you're " "Select the reasoning parser depending on the model that you're "
......
...@@ -2080,7 +2080,8 @@ class LLMEngine: ...@@ -2080,7 +2080,8 @@ class LLMEngine:
guided_decoding.backend = guided_decoding.backend or \ guided_decoding.backend = guided_decoding.backend or \
self.decoding_config.guided_decoding_backend self.decoding_config.guided_decoding_backend
logger.debug("Reasoning backend: %s", if self.decoding_config.reasoning_backend is not None:
logger.debug("Building with reasoning backend %s",
self.decoding_config.reasoning_backend) self.decoding_config.reasoning_backend)
processor = get_local_guided_decoding_logits_processor( processor = get_local_guided_decoding_logits_processor(
......
...@@ -68,7 +68,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ...@@ -68,7 +68,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TranscriptionRequest, TranscriptionRequest,
TranscriptionResponse, TranscriptionResponse,
UnloadLoRAAdapterRequest) UnloadLoRAAdapterRequest)
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
# yapf: enable # yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
...@@ -85,6 +84,7 @@ from vllm.entrypoints.openai.serving_transcription import ( ...@@ -85,6 +84,7 @@ from vllm.entrypoints.openai.serving_transcription import (
from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import load_aware_call, with_cancellation from vllm.entrypoints.utils import load_aware_call, with_cancellation
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value) maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import MistralTokenizer from vllm.transformers_utils.tokenizer import MistralTokenizer
......
...@@ -23,8 +23,6 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -23,8 +23,6 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo, DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
RequestResponseMetadata, ToolCall, UsageInfo) RequestResponseMetadata, ToolCall, UsageInfo)
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
ReasoningParserManager)
from vllm.entrypoints.openai.serving_engine import (OpenAIServing, from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
clamp_prompt_logprobs) clamp_prompt_logprobs)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
...@@ -33,6 +31,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( ...@@ -33,6 +31,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
MistralToolCall) MistralToolCall)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
......
...@@ -5,10 +5,10 @@ from __future__ import annotations ...@@ -5,10 +5,10 @@ from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.reasoner import get_reasoner
from vllm.model_executor.guided_decoding.utils import ( from vllm.model_executor.guided_decoding.utils import (
convert_lark_to_gbnf, grammar_is_likely_lark, convert_lark_to_gbnf, grammar_is_likely_lark,
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features) has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
from vllm.reasoning import ReasoningParserManager
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
...@@ -107,7 +107,11 @@ async def get_guided_decoding_logits_processor( ...@@ -107,7 +107,11 @@ async def get_guided_decoding_logits_processor(
model_config: ModelConfig, model_config: ModelConfig,
reasoning_backend: str | None = None) -> LogitsProcessor | None: reasoning_backend: str | None = None) -> LogitsProcessor | None:
reasoner = get_reasoner(tokenizer, reasoning_backend) reasoner = None
if reasoning_backend is not None:
reasoner_class = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
reasoner = reasoner_class(tokenizer)
guided_params = maybe_backend_fallback(guided_params) guided_params = maybe_backend_fallback(guided_params)
...@@ -146,8 +150,11 @@ def get_local_guided_decoding_logits_processor( ...@@ -146,8 +150,11 @@ def get_local_guided_decoding_logits_processor(
reasoning_backend: str | None = None) -> LogitsProcessor | None: reasoning_backend: str | None = None) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params) guided_params = maybe_backend_fallback(guided_params)
# Get the reasoner if needed, it will be None if reasoning_ reasoner = None
reasoner = get_reasoner(tokenizer, reasoning_backend) if reasoning_backend is not None:
reasoner_class = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
reasoner = reasoner_class(tokenizer)
# CFG grammar not supported by LMFE, so we use outlines instead # CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend_name == 'outlines': if guided_params.backend_name == 'outlines':
......
...@@ -12,7 +12,7 @@ from transformers import PreTrainedTokenizerBase ...@@ -12,7 +12,7 @@ from transformers import PreTrainedTokenizerBase
from vllm.model_executor.guided_decoding.outlines_logits_processors import ( from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.model_executor.guided_decoding.reasoner import Reasoner from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams from vllm.sampling_params import GuidedDecodingParams
...@@ -61,7 +61,7 @@ _MAX_THREADPOOL_WORKERS = 16 ...@@ -61,7 +61,7 @@ _MAX_THREADPOOL_WORKERS = 16
async def get_outlines_guided_decoding_logits_processor( async def get_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner], reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]: None]:
""" """
...@@ -92,7 +92,7 @@ async def get_outlines_guided_decoding_logits_processor( ...@@ -92,7 +92,7 @@ async def get_outlines_guided_decoding_logits_processor(
def get_local_outlines_guided_decoding_logits_processor( def get_local_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner], reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]: None]:
""" """
...@@ -141,7 +141,7 @@ def _get_logits_processor( ...@@ -141,7 +141,7 @@ def _get_logits_processor(
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
mode: GuidedDecodingMode, mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None], whitespace_pattern: Union[str, None],
reasoner: Optional[Reasoner], reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]: ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
if mode == GuidedDecodingMode.JSON: if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern, return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,
......
...@@ -34,8 +34,8 @@ from transformers import PreTrainedTokenizerBase ...@@ -34,8 +34,8 @@ from transformers import PreTrainedTokenizerBase
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.reasoning import ReasoningParser
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -49,9 +49,9 @@ else: ...@@ -49,9 +49,9 @@ else:
class BaseLogitsProcessor: class BaseLogitsProcessor:
def __init__(self, guide: Guide, reasoner: Optional[Reasoner]): def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]):
self._guide: Guide = guide self._guide: Guide = guide
self._reasoner: Optional[Reasoner] = reasoner self._reasoner: Optional[ReasoningParser] = reasoner
# CFGState is used for the FSM state for CFGGuide # CFGState is used for the FSM state for CFGGuide
self._fsm_state: DefaultDict[int, Union[int, self._fsm_state: DefaultDict[int, Union[int,
CFGState]] = defaultdict(int) CFGState]] = defaultdict(int)
...@@ -69,7 +69,7 @@ class BaseLogitsProcessor: ...@@ -69,7 +69,7 @@ class BaseLogitsProcessor:
# Remove the reasoning tokens from the input_ids # Remove the reasoning tokens from the input_ids
# We need this because our implementation relies on the # We need this because our implementation relies on the
# hash of the input_ids to store the FSM state. # hash of the input_ids to store the FSM state.
input_ids = self._reasoner.extract_content(input_ids) input_ids = self._reasoner.extract_content_ids(input_ids)
seq_id = hash(tuple(input_ids)) seq_id = hash(tuple(input_ids))
...@@ -142,7 +142,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor): ...@@ -142,7 +142,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
self, self,
regex_string: str, regex_string: str,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner], reasoner: Optional[ReasoningParser],
): ):
"""Compile the FSM that drives the regex-structured generation. """Compile the FSM that drives the regex-structured generation.
...@@ -163,7 +163,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor): ...@@ -163,7 +163,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self, schema: Union[str, Dict, BaseModel], def __init__(self, schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Union[str, None], whitespace_pattern: Union[str, None],
reasoner: Optional[Reasoner]): reasoner: Optional[ReasoningParser]):
"""Compile the FSM that drives the JSON-guided generation. """Compile the FSM that drives the JSON-guided generation.
Parameters Parameters
...@@ -203,7 +203,7 @@ class CFGLogitsProcessor(BaseLogitsProcessor): ...@@ -203,7 +203,7 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
return CFGGuide(cfg, tokenizer) return CFGGuide(cfg, tokenizer)
def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase, def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner]): reasoner: Optional[ReasoningParser]):
"""Compile the FSM that drives the context free grammar generation. """Compile the FSM that drives the context free grammar generation.
Parameters Parameters
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner
@dataclass
class DeepSeekReasoner(Reasoner):
"""
Reasoner for DeepSeek R series models.
"""
start_token_id: int
end_token_id: int
start_token: str = "<think>"
end_token: str = "</think>"
@classmethod
def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
return cls(start_token_id=tokenizer.encode(
"<think>", add_special_tokens=False)[0],
end_token_id=tokenizer.encode("</think>",
add_special_tokens=False)[0])
def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.end_token_id in input_ids
def extract_content(self, input_ids: list[int]) -> list[int]:
"""
Extract the content after the end tokens
"""
if self.end_token_id not in input_ids or \
input_ids.index(self.end_token_id) + 1 == len(input_ids):
return []
else:
return input_ids[input_ids.index(self.end_token_id) + 1:]
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
@dataclass
class Reasoner(ABC):
@abstractmethod
def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
pass
@abstractmethod
def is_reasoning_end(self, input_ids: list[int]) -> bool:
pass
@abstractmethod
def extract_content(self, input_ids: list[int]) -> list[int]:
pass
...@@ -27,7 +27,7 @@ if TYPE_CHECKING: ...@@ -27,7 +27,7 @@ if TYPE_CHECKING:
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.model_executor.guided_decoding.reasoner import Reasoner from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams from vllm.sampling_params import GuidedDecodingParams
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -37,7 +37,7 @@ def get_local_xgrammar_guided_decoding_logits_processor( ...@@ -37,7 +37,7 @@ def get_local_xgrammar_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
model_config: ModelConfig, model_config: ModelConfig,
reasoner: Reasoner | None, reasoner: ReasoningParser | None,
max_threads: int = 8): max_threads: int = 8):
config = GrammarConfig.from_guided_params(guided_params=guided_params, config = GrammarConfig.from_guided_params(guided_params=guided_params,
model_config=model_config, model_config=model_config,
...@@ -280,7 +280,7 @@ class GrammarConfig: ...@@ -280,7 +280,7 @@ class GrammarConfig:
class XGrammarLogitsProcessor: class XGrammarLogitsProcessor:
"""Wrapper class to support pickle protocol""" """Wrapper class to support pickle protocol"""
config: GrammarConfig config: GrammarConfig
reasoner: Reasoner | None = None reasoner: ReasoningParser | None = None
ctx: xgr.CompiledGrammar | None = None ctx: xgr.CompiledGrammar | None = None
tokenizer_info: xgr.TokenizerInfo = None # type: ignore[assignment] tokenizer_info: xgr.TokenizerInfo = None # type: ignore[assignment]
......
...@@ -32,6 +32,36 @@ class ReasoningParser: ...@@ -32,6 +32,36 @@ class ReasoningParser:
# whereas all tokenizers have .get_vocab() # whereas all tokenizers have .get_vocab()
return self.model_tokenizer.get_vocab() return self.model_tokenizer.get_vocab()
@abstractmethod
def is_reasoning_end(self, input_ids: list[int]) -> bool:
"""
Check if the reasoning content ends in the input_ids.
It is used in structured engines like `xgrammar` to check if the
reasoning content ends in the model output.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
bool
True if the reasoning content ends in the input_ids.
"""
@abstractmethod
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract content token ids from the input_ids.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
list[int]
The extracted content from the input_ids.
"""
@abstractmethod
def extract_reasoning_content( def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest self, model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]: ) -> tuple[Optional[str], Optional[str]]:
...@@ -53,10 +83,7 @@ class ReasoningParser: ...@@ -53,10 +83,7 @@ class ReasoningParser:
A tuple containing the reasoning content and the content. A tuple containing the reasoning content and the content.
""" """
raise NotImplementedError( @abstractmethod
"AbstractReasoningParser.extract_reasoning_calls "
"has not been implemented!")
def extract_reasoning_content_streaming( def extract_reasoning_content_streaming(
self, self,
previous_text: str, previous_text: str,
...@@ -73,43 +100,6 @@ class ReasoningParser: ...@@ -73,43 +100,6 @@ class ReasoningParser:
the current tokens/diffs, but also the information about what has the current tokens/diffs, but also the information about what has
previously been parsed and extracted (see constructor) previously been parsed and extracted (see constructor)
""" """
raise NotImplementedError(
"AbstractReasoningParser.extract_reasoning_content_streaming "
"has not been implemented!")
# TODO: need to rebase by PR #14428
@abstractmethod
def is_reasoning_end(self, input_ids: list[int]) -> bool:
"""
Check if the reasoning content ends in the input_ids.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
bool
True if the reasoning content ends in the input_ids.
"""
raise NotImplementedError(
"AbstractReasoningParser.is_reasoning_end has"
"not been implemented!")
# TODO: need to rebase by PR #14428
@abstractmethod
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract content token ids from the input_ids.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
list[int]
The extracted content from the input_ids.
"""
raise NotImplementedError(
"AbstractReasoningParser.extract_content_ids has"
" not been implemented!")
class ReasoningParserManager: class ReasoningParserManager:
...@@ -125,14 +115,16 @@ class ReasoningParserManager: ...@@ -125,14 +115,16 @@ class ReasoningParserManager:
if name in cls.reasoning_parsers: if name in cls.reasoning_parsers:
return cls.reasoning_parsers[name] return cls.reasoning_parsers[name]
raise KeyError(f"reasoning helper: '{name}' not found in " raise KeyError(
"reasoning_parsers") f"reasoning helper: '{name}' not found in reasoning_parsers")
@classmethod @classmethod
def _register_module(cls, def _register_module(
cls,
module: type, module: type,
module_name: Optional[Union[str, list[str]]] = None, module_name: Optional[Union[str, list[str]]] = None,
force: bool = True) -> None: force: bool = True,
) -> None:
if not issubclass(module, ReasoningParser): if not issubclass(module, ReasoningParser):
raise TypeError("module must be subclass of ReasoningParser, " raise TypeError("module must be subclass of ReasoningParser, "
f"but got {type(module)}") f"but got {type(module)}")
...@@ -152,7 +144,8 @@ class ReasoningParserManager: ...@@ -152,7 +144,8 @@ class ReasoningParserManager:
cls, cls,
name: Optional[Union[str, list[str]]] = None, name: Optional[Union[str, list[str]]] = None,
force: bool = True, force: bool = True,
module: Union[type, None] = None) -> Union[type, Callable]: module: Union[type, None] = None,
) -> Union[type, Callable]:
""" """
Register module with the given name or name list. it can be used as a Register module with the given name or name list. it can be used as a
decoder(with module as None) or normal function(with module as not decoder(with module as None) or normal function(with module as not
......
...@@ -8,9 +8,8 @@ from transformers import PreTrainedTokenizerBase ...@@ -8,9 +8,8 @@ from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage) DeltaMessage)
from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import (
ReasoningParser, ReasoningParserManager)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -24,39 +23,41 @@ class DeepSeekR1ReasoningParser(ReasoningParser): ...@@ -24,39 +23,41 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
text. This parser extracts the reasoning content from the model output. text. This parser extracts the reasoning content from the model output.
""" """
start_token_id: int
end_token_id: int
start_token: str = "<think>"
end_token: str = "</think>"
def __init__(self, tokenizer: PreTrainedTokenizerBase): def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer) super().__init__(tokenizer)
self.think_start_token = "<think>"
self.think_end_token = "</think>"
self.reasoning_regex = re.compile( self.reasoning_regex = re.compile(
rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL) rf"{self.start_token}(.*?){self.end_token}", re.DOTALL)
if not self.model_tokenizer: if not self.model_tokenizer:
raise ValueError( raise ValueError(
"The model tokenizer must be passed to the ReasoningParser " "The model tokenizer must be passed to the ReasoningParser "
"constructor during construction.") "constructor during construction.")
self.think_start_token_id = self.vocab.get(self.think_start_token) self.start_token_id = self.vocab.get(self.start_token)
self.think_end_token_id = self.vocab.get(self.think_end_token) self.end_token_id = self.vocab.get(self.end_token)
if (self.think_start_token_id is None if self.start_token_id is None or self.end_token_id is None:
or self.think_end_token_id is None):
raise RuntimeError( raise RuntimeError(
"DeepSeek R1 reasoning parser could not locate think start/end " "DeepSeek R1 reasoning parser could not locate think start/end "
"tokens in the tokenizer!") "tokens in the tokenizer!")
# TODO: need to rebase by PR #14428
def is_reasoning_end(self, input_ids: list[int]) -> bool: def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.think_end_token_id in input_ids return self.end_token_id in input_ids
def extract_content_ids(self, input_ids: list[int]) -> list[int]: def extract_content_ids(self, input_ids: list[int]) -> list[int]:
""" """
Extract the content after the end tokens Extract the content after the end tokens
""" """
if self.think_end_token_id not in input_ids[:-1]: if self.end_token_id not in input_ids[:-1]:
return [] return []
else: else:
return input_ids[input_ids.index(self.think_end_token_id) + 1:] return input_ids[input_ids.index(self.end_token_id) + 1:]
def extract_reasoning_content_streaming( def extract_reasoning_content_streaming(
self, self,
...@@ -77,22 +78,24 @@ class DeepSeekR1ReasoningParser(ReasoningParser): ...@@ -77,22 +78,24 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
""" """
# Skip single special tokens # Skip single special tokens
if len(delta_token_ids) == 1 and (delta_token_ids[0] in [ if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
self.think_start_token_id, self.think_end_token_id self.start_token_id, self.end_token_id
]): ]):
return None return None
# Check if <think> is present in previous or delta. # Check if <think> is present in previous or delta.
# Keep compatibility with models that don't generate <think> tokens. # Keep compatibility with models that don't generate <think> tokens.
if self.think_start_token_id in previous_token_ids: if self.start_token_id in previous_token_ids:
if self.think_end_token_id in delta_token_ids: if self.end_token_id in delta_token_ids:
# <think> in previous, </think> in delta, # <think> in previous, </think> in delta,
# extract reasoning content # extract reasoning content
end_index = delta_text.find(self.think_end_token) end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[:end_index] reasoning_content = delta_text[:end_index]
content = delta_text[end_index + len(self.think_end_token):] content = delta_text[end_index + len(self.end_token):]
return DeltaMessage(reasoning_content=reasoning_content, return DeltaMessage(
content=content if content else None) reasoning_content=reasoning_content,
elif self.think_end_token_id in previous_token_ids: content=content if content else None,
)
elif self.end_token_id in previous_token_ids:
# <think> in previous, </think> in previous, # <think> in previous, </think> in previous,
# reasoning content continues # reasoning content continues
return DeltaMessage(content=delta_text) return DeltaMessage(content=delta_text)
...@@ -100,17 +103,18 @@ class DeepSeekR1ReasoningParser(ReasoningParser): ...@@ -100,17 +103,18 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# <think> in previous, no </think> in previous or delta, # <think> in previous, no </think> in previous or delta,
# reasoning content continues # reasoning content continues
return DeltaMessage(reasoning_content=delta_text) return DeltaMessage(reasoning_content=delta_text)
elif self.think_start_token_id in delta_token_ids: elif self.start_token_id in delta_token_ids:
if self.think_end_token_id in delta_token_ids: if self.end_token_id in delta_token_ids:
# <think> in delta, </think> in delta, extract reasoning content # <think> in delta, </think> in delta, extract reasoning content
start_index = delta_text.find(self.think_start_token) start_index = delta_text.find(self.start_token)
end_index = delta_text.find(self.think_end_token) end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[start_index + reasoning_content = delta_text[start_index +
len(self.think_start_token len(self.start_token):end_index]
):end_index] content = delta_text[end_index + len(self.end_token):]
content = delta_text[end_index + len(self.think_end_token):] return DeltaMessage(
return DeltaMessage(reasoning_content=reasoning_content, reasoning_content=reasoning_content,
content=content if content else None) content=content if content else None,
)
else: else:
# <think> in delta, no </think> in delta, # <think> in delta, no </think> in delta,
# reasoning content continues # reasoning content continues
...@@ -119,15 +123,17 @@ class DeepSeekR1ReasoningParser(ReasoningParser): ...@@ -119,15 +123,17 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# No <think> in previous or delta, also need to check for </think>. # No <think> in previous or delta, also need to check for </think>.
# Because the model may have generated </think> without <think> # Because the model may have generated </think> without <think>
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if self.think_end_token_id in delta_token_ids: if self.end_token_id in delta_token_ids:
# </think> in delta with more tokens, # </think> in delta with more tokens,
# extract reasoning content and content # extract reasoning content and content
end_index = delta_text.find(self.think_end_token) end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[:end_index] reasoning_content = delta_text[:end_index]
content = delta_text[end_index + len(self.think_end_token):] content = delta_text[end_index + len(self.end_token):]
return DeltaMessage(reasoning_content=reasoning_content, return DeltaMessage(
content=content if content else None) reasoning_content=reasoning_content,
elif self.think_end_token_id in previous_token_ids: content=content if content else None,
)
elif self.end_token_id in previous_token_ids:
# </think> in previous, thinking content ends # </think> in previous, thinking content ends
return DeltaMessage(content=delta_text) return DeltaMessage(content=delta_text)
else: else:
...@@ -137,22 +143,20 @@ class DeepSeekR1ReasoningParser(ReasoningParser): ...@@ -137,22 +143,20 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
def extract_reasoning_content( def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest self, model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]: ) -> tuple[Optional[str], Optional[str]]:
# DeepSeek R1 doesn't generate <think> now. # DeepSeek R1 doesn't generate <think> now.
# Thus we assume the reasoning content is always at the start. # Thus we assume the reasoning content is always at the start.
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if self.think_end_token not in model_output: if self.end_token not in model_output:
return model_output, None return model_output, None
else: else:
# Add a start token if it's missing to keep compatibility. # Add a start token if it's missing to keep compatibility.
if self.think_start_token not in model_output: if self.start_token not in model_output:
model_output = f"{self.think_start_token}{model_output}" model_output = f"{self.start_token}{model_output}"
# Use a regex to find the reasoning content # Use a regex to find the reasoning content
reasoning_content = self.reasoning_regex.findall(model_output)[0] reasoning_content = self.reasoning_regex.findall(model_output)[0]
end_index = len( end_index = len(
f"{self.think_start_token}{reasoning_content}{self.think_end_token}" f"{self.start_token}{reasoning_content}{self.end_token}")
)
final_output = model_output[end_index:] final_output = model_output[end_index:]
if len(final_output) == 0: if len(final_output) == 0:
......
...@@ -8,9 +8,8 @@ from transformers import PreTrainedTokenizerBase ...@@ -8,9 +8,8 @@ from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage) DeltaMessage)
from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import (
ReasoningParser, ReasoningParserManager)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager
logger = init_logger(__name__) logger = init_logger(__name__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment