"vllm/vscode:/vscode.git/clone" did not exist on "386e5c47eb5bdb8fd9f159159ee99d0f1e8502f4"
Unverified commit 34a98427, authored by Cyrus Leung, committed by GitHub
Browse files

[Misc] Refactor tokenizer interface (#29693)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent f223ed41
......@@ -22,18 +22,18 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.tokenizers import MistralTokenizer, TokenizerLike
logger = init_logger(__name__)
class Hermes2ProToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
if isinstance(self.model_tokenizer, MistralTokenizer):
if isinstance(tokenizer, MistralTokenizer):
logger.error("Detected Mistral tokenizer when using a Hermes model")
self.model_tokenizer = self.model_tokenizer.tokenizer
self.model_tokenizer = tokenizer.tokenizer
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
......
......@@ -22,14 +22,14 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
)
from vllm.entrypoints.openai.tool_parsers.utils import consume_space
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid
logger = init_logger(__name__)
class HunyuanA13BToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
# Initialize state for streaming mode
......
......@@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
)
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class Internlm2ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.position = 0
......
......@@ -21,14 +21,13 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.tokenizers import MistralTokenizer, TokenizerLike
logger = init_logger(__name__)
class JambaToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
if isinstance(self.model_tokenizer, MistralTokenizer):
......
......@@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class KimiK2ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
......
......@@ -4,11 +4,11 @@
import regex as re
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
class LongcatFlashToolParser(Hermes2ProToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.tool_call_start_token: str = "<longcat_tool_call>"
......
......@@ -21,13 +21,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class MinimaxM2ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.prev_tool_call_arr: list[dict] = []
......
......@@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
)
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class MinimaxToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
# Initialize streaming state for tracking tool call progress
......
......@@ -25,7 +25,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
)
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.tokenizers import MistralTokenizer, TokenizerLike
logger = init_logger(__name__)
......@@ -46,7 +46,7 @@ class MistralToolCall(ToolCall):
return id.isalnum() and len(id) == 9
def _is_fn_name_regex_support(model_tokenizer: AnyTokenizer) -> bool:
def _is_fn_name_regex_support(model_tokenizer: TokenizerLike) -> bool:
return (
isinstance(model_tokenizer, MistralTokenizer) and model_tokenizer.version >= 11
)
......@@ -61,7 +61,7 @@ class MistralToolParser(ToolParser):
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
"""
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
if not isinstance(self.model_tokenizer, MistralTokenizer):
......
......@@ -18,15 +18,15 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
else:
AnyTokenizer = object
TokenizerLike = object
logger = init_logger(__name__)
class OpenAIToolParser(ToolParser):
def __init__(self, tokenizer: "AnyTokenizer"):
def __init__(self, tokenizer: "TokenizerLike"):
super().__init__(tokenizer)
def extract_tool_calls(
......
......@@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class Qwen3CoderToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.current_tool_name_sent: bool = False
......
......@@ -23,7 +23,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
......@@ -1165,7 +1165,7 @@ class StreamingXMLToolCallParser:
class Qwen3XMLToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.parser = StreamingXMLToolCallParser()
......
......@@ -25,7 +25,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
......@@ -34,7 +34,7 @@ class SeedOssToolParser(ToolParser):
TOOL_CALL_START = "<seed:tool_call>"
TOOL_CALL_END = "</seed:tool_call>"
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
# --- streaming state ---
......
......@@ -21,7 +21,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid
logger = init_logger(__name__)
......@@ -41,7 +41,7 @@ class Step3ToolParser(ToolParser):
TOOL_SEP = "<|tool_sep|>"
SPECIAL_TOKENS = [TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END]
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.position = 0
# Explicit state flags for robust streaming
......
......@@ -21,14 +21,14 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid
logger = init_logger(__name__)
class xLAMToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
# Initialize state for streaming mode
......
......@@ -16,7 +16,7 @@ from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
from vllm.inputs.data import TextPrompt as EngineTextPrompt
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
......@@ -85,7 +85,7 @@ class BaseRenderer(ABC):
def __init__(
self,
model_config: ModelConfig,
tokenizer: AnyTokenizer | None = None,
tokenizer: TokenizerLike | None = None,
):
super().__init__()
self.model_config = model_config
......@@ -200,8 +200,8 @@ class CompletionRenderer(BaseRenderer):
def __init__(
self,
model_config: ModelConfig,
tokenizer: AnyTokenizer | None = None,
async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer]
tokenizer: TokenizerLike | None = None,
async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer]
| None = None,
):
super().__init__(model_config, tokenizer)
......@@ -373,7 +373,7 @@ class CompletionRenderer(BaseRenderer):
return async_tokenizer
tokenizer = self.tokenizer
if self.tokenizer is None:
if tokenizer is None:
raise ValueError("No tokenizer available for text input processing")
if self.async_tokenizer_pool is None:
......
......@@ -19,11 +19,7 @@ from vllm.inputs import TokensPrompt
from vllm.model_executor.models.interfaces import supports_score_template
from vllm.multimodal.inputs import MultiModalDataDict
from vllm.outputs import PoolingRequestOutput
from vllm.transformers_utils.tokenizer import (
AnyTokenizer,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
from vllm.transformers_utils.tokenizer import TokenizerLike
ScoreContentPartParam: TypeAlias = (
ChatCompletionContentPartImageParam | ChatCompletionContentPartImageEmbedsParam
......@@ -45,7 +41,7 @@ class ScoreMultiModalParam(TypedDict, total=False):
def _cosine_similarity(
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
tokenizer: TokenizerLike,
embed_1: list[PoolingRequestOutput],
embed_2: list[PoolingRequestOutput],
) -> list[PoolingRequestOutput]:
......@@ -93,7 +89,7 @@ def parse_score_data(
data_1: str | ScoreContentPartParam,
data_2: str | ScoreContentPartParam,
model_config: ModelConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
) -> tuple[str, str, MultiModalDataDict | None]:
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
......@@ -118,12 +114,14 @@ def _parse_score_content(
mm_tracker: BaseMultiModalItemTracker,
) -> _ContentPart | None:
if isinstance(data, str):
data = ChatCompletionContentPartTextParam(type="text", text=data)
part = ChatCompletionContentPartTextParam(type="text", text=data)
else:
part = data
mm_parser = mm_tracker.create_parser()
parse_res = _parse_chat_message_content_part(
data,
part,
mm_parser,
wrap_dicts=False,
interleave_strings=False,
......@@ -181,7 +179,7 @@ def post_process_tokens(
def get_score_prompt(
model_config: ModelConfig,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
tokenization_kwargs: dict[str, Any],
data_1: str | ScoreContentPartParam,
data_2: str | ScoreContentPartParam,
......
......@@ -30,7 +30,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_models import LoRAModulePath
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.tokenizers import MistralTokenizer
from vllm.utils.argparse_utils import FlexibleArgumentParser
logger = init_logger(__name__)
......
......@@ -17,7 +17,7 @@ from vllm.multimodal.inputs import (
MultiModalUUIDDict,
)
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.utils.jsontree import json_iter_leaves
from vllm.v1.metrics.stats import MultiModalCacheStats
......@@ -46,7 +46,7 @@ class InputPreprocessor:
def __init__(
self,
model_config: ModelConfig,
tokenizer: AnyTokenizer | None,
tokenizer: TokenizerLike | None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: BaseMultiModalProcessorCache | None = None,
) -> None:
......@@ -59,7 +59,7 @@ class InputPreprocessor:
self.mm_cache_stats = MultiModalCacheStats() if mm_processor_cache else None
def get_tokenizer(self) -> AnyTokenizer:
def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None:
raise ValueError(
"You cannot pass text prompts when `skip_tokenizer_init` is True"
......@@ -228,11 +228,11 @@ class InputPreprocessor:
return tokenizer.encode(prompt, **tokenization_kwargs)
def _get_mm_tokenizer(self) -> AnyTokenizer:
def _get_mm_tokenizer(self) -> TokenizerLike:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input
if not self.tokenizer:
return cast(AnyTokenizer, object()) # Dummy
return cast(TokenizerLike, object()) # Dummy
tokenizer = self.get_tokenizer()
return tokenizer
......
......@@ -5,7 +5,7 @@ from typing import TypeAlias
import torch
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
LogitsProcessor: TypeAlias = (
Callable[[list[int], torch.Tensor], torch.Tensor]
......@@ -19,7 +19,7 @@ to sample from."""
def get_bad_words_logits_processors(
bad_words: list[str], tokenizer: AnyTokenizer
bad_words: list[str], tokenizer: TokenizerLike
) -> list[LogitsProcessor]:
bad_words_ids: list[list[int]] = list()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment