Unverified Commit 34a98427 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Refactor tokenizer interface (#29693)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent f223ed41
...@@ -22,18 +22,18 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ...@@ -22,18 +22,18 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.tokenizers import MistralTokenizer, TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
class Hermes2ProToolParser(ToolParser): class Hermes2ProToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
if isinstance(self.model_tokenizer, MistralTokenizer): if isinstance(tokenizer, MistralTokenizer):
logger.error("Detected Mistral tokenizer when using a Hermes model") logger.error("Detected Mistral tokenizer when using a Hermes model")
self.model_tokenizer = self.model_tokenizer.tokenizer self.model_tokenizer = tokenizer.tokenizer
self.current_tool_name_sent: bool = False self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = [] self.prev_tool_call_arr: list[dict] = []
......
...@@ -22,14 +22,14 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ...@@ -22,14 +22,14 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
) )
from vllm.entrypoints.openai.tool_parsers.utils import consume_space from vllm.entrypoints.openai.tool_parsers.utils import consume_space
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid from vllm.utils import random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
class HunyuanA13BToolParser(ToolParser): class HunyuanA13BToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
# Initialize state for streaming mode # Initialize state for streaming mode
......
...@@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ...@@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
) )
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
class Internlm2ToolParser(ToolParser): class Internlm2ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
self.position = 0 self.position = 0
......
...@@ -21,14 +21,13 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -21,14 +21,13 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.tool_parsers import ToolParser from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.transformers_utils.tokenizers import MistralTokenizer
logger = init_logger(__name__) logger = init_logger(__name__)
class JambaToolParser(ToolParser): class JambaToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
if isinstance(self.model_tokenizer, MistralTokenizer): if isinstance(self.model_tokenizer, MistralTokenizer):
......
...@@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ...@@ -19,13 +19,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
class KimiK2ToolParser(ToolParser): class KimiK2ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
self.current_tool_name_sent: bool = False self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = [] self.prev_tool_call_arr: list[dict] = []
......
...@@ -4,11 +4,11 @@ ...@@ -4,11 +4,11 @@
import regex as re import regex as re
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
class LongcatFlashToolParser(Hermes2ProToolParser): class LongcatFlashToolParser(Hermes2ProToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
self.tool_call_start_token: str = "<longcat_tool_call>" self.tool_call_start_token: str = "<longcat_tool_call>"
......
...@@ -21,13 +21,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ...@@ -21,13 +21,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
class MinimaxM2ToolParser(ToolParser): class MinimaxM2ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
self.prev_tool_call_arr: list[dict] = [] self.prev_tool_call_arr: list[dict] = []
......
...@@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ...@@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
) )
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
class MinimaxToolParser(ToolParser): class MinimaxToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
# Initialize streaming state for tracking tool call progress # Initialize streaming state for tracking tool call progress
......
...@@ -25,7 +25,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ...@@ -25,7 +25,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
) )
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.tokenizers import MistralTokenizer, TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -46,7 +46,7 @@ class MistralToolCall(ToolCall): ...@@ -46,7 +46,7 @@ class MistralToolCall(ToolCall):
return id.isalnum() and len(id) == 9 return id.isalnum() and len(id) == 9
def _is_fn_name_regex_support(model_tokenizer: AnyTokenizer) -> bool: def _is_fn_name_regex_support(model_tokenizer: TokenizerLike) -> bool:
return ( return (
isinstance(model_tokenizer, MistralTokenizer) and model_tokenizer.version >= 11 isinstance(model_tokenizer, MistralTokenizer) and model_tokenizer.version >= 11
) )
...@@ -61,7 +61,7 @@ class MistralToolParser(ToolParser): ...@@ -61,7 +61,7 @@ class MistralToolParser(ToolParser):
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
""" """
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
if not isinstance(self.model_tokenizer, MistralTokenizer): if not isinstance(self.model_tokenizer, MistralTokenizer):
......
...@@ -18,15 +18,15 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ...@@ -18,15 +18,15 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
from vllm.logger import init_logger from vllm.logger import init_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
else: else:
AnyTokenizer = object TokenizerLike = object
logger = init_logger(__name__) logger = init_logger(__name__)
class OpenAIToolParser(ToolParser): class OpenAIToolParser(ToolParser):
def __init__(self, tokenizer: "AnyTokenizer"): def __init__(self, tokenizer: "TokenizerLike"):
super().__init__(tokenizer) super().__init__(tokenizer)
def extract_tool_calls( def extract_tool_calls(
......
...@@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ...@@ -22,13 +22,13 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
class Qwen3CoderToolParser(ToolParser): class Qwen3CoderToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
self.current_tool_name_sent: bool = False self.current_tool_name_sent: bool = False
......
...@@ -23,7 +23,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ...@@ -23,7 +23,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -1165,7 +1165,7 @@ class StreamingXMLToolCallParser: ...@@ -1165,7 +1165,7 @@ class StreamingXMLToolCallParser:
class Qwen3XMLToolParser(ToolParser): class Qwen3XMLToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
self.parser = StreamingXMLToolCallParser() self.parser = StreamingXMLToolCallParser()
......
...@@ -25,7 +25,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ...@@ -25,7 +25,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -34,7 +34,7 @@ class SeedOssToolParser(ToolParser): ...@@ -34,7 +34,7 @@ class SeedOssToolParser(ToolParser):
TOOL_CALL_START = "<seed:tool_call>" TOOL_CALL_START = "<seed:tool_call>"
TOOL_CALL_END = "</seed:tool_call>" TOOL_CALL_END = "</seed:tool_call>"
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
# --- streaming state --- # --- streaming state ---
......
...@@ -21,7 +21,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ...@@ -21,7 +21,7 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid from vllm.utils import random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -41,7 +41,7 @@ class Step3ToolParser(ToolParser): ...@@ -41,7 +41,7 @@ class Step3ToolParser(ToolParser):
TOOL_SEP = "<|tool_sep|>" TOOL_SEP = "<|tool_sep|>"
SPECIAL_TOKENS = [TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END] SPECIAL_TOKENS = [TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END]
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
self.position = 0 self.position = 0
# Explicit state flags for robust streaming # Explicit state flags for robust streaming
......
...@@ -21,14 +21,14 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ...@@ -21,14 +21,14 @@ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParser,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.utils import random_uuid from vllm.utils import random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
class xLAMToolParser(ToolParser): class xLAMToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer): def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer) super().__init__(tokenizer)
# Initialize state for streaming mode # Initialize state for streaming mode
......
...@@ -16,7 +16,7 @@ from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt ...@@ -16,7 +16,7 @@ from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
from vllm.inputs.data import TextPrompt as EngineTextPrompt from vllm.inputs.data import TextPrompt as EngineTextPrompt
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import get_prompt_components, parse_raw_prompts from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer from vllm.utils.async_utils import AsyncMicrobatchTokenizer
...@@ -85,7 +85,7 @@ class BaseRenderer(ABC): ...@@ -85,7 +85,7 @@ class BaseRenderer(ABC):
def __init__( def __init__(
self, self,
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: AnyTokenizer | None = None, tokenizer: TokenizerLike | None = None,
): ):
super().__init__() super().__init__()
self.model_config = model_config self.model_config = model_config
...@@ -200,8 +200,8 @@ class CompletionRenderer(BaseRenderer): ...@@ -200,8 +200,8 @@ class CompletionRenderer(BaseRenderer):
def __init__( def __init__(
self, self,
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: AnyTokenizer | None = None, tokenizer: TokenizerLike | None = None,
async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] async_tokenizer_pool: dict[TokenizerLike, AsyncMicrobatchTokenizer]
| None = None, | None = None,
): ):
super().__init__(model_config, tokenizer) super().__init__(model_config, tokenizer)
...@@ -373,7 +373,7 @@ class CompletionRenderer(BaseRenderer): ...@@ -373,7 +373,7 @@ class CompletionRenderer(BaseRenderer):
return async_tokenizer return async_tokenizer
tokenizer = self.tokenizer tokenizer = self.tokenizer
if self.tokenizer is None: if tokenizer is None:
raise ValueError("No tokenizer available for text input processing") raise ValueError("No tokenizer available for text input processing")
if self.async_tokenizer_pool is None: if self.async_tokenizer_pool is None:
......
...@@ -19,11 +19,7 @@ from vllm.inputs import TokensPrompt ...@@ -19,11 +19,7 @@ from vllm.inputs import TokensPrompt
from vllm.model_executor.models.interfaces import supports_score_template from vllm.model_executor.models.interfaces import supports_score_template
from vllm.multimodal.inputs import MultiModalDataDict from vllm.multimodal.inputs import MultiModalDataDict
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.transformers_utils.tokenizer import ( from vllm.transformers_utils.tokenizer import TokenizerLike
AnyTokenizer,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)
ScoreContentPartParam: TypeAlias = ( ScoreContentPartParam: TypeAlias = (
ChatCompletionContentPartImageParam | ChatCompletionContentPartImageEmbedsParam ChatCompletionContentPartImageParam | ChatCompletionContentPartImageEmbedsParam
...@@ -45,7 +41,7 @@ class ScoreMultiModalParam(TypedDict, total=False): ...@@ -45,7 +41,7 @@ class ScoreMultiModalParam(TypedDict, total=False):
def _cosine_similarity( def _cosine_similarity(
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, tokenizer: TokenizerLike,
embed_1: list[PoolingRequestOutput], embed_1: list[PoolingRequestOutput],
embed_2: list[PoolingRequestOutput], embed_2: list[PoolingRequestOutput],
) -> list[PoolingRequestOutput]: ) -> list[PoolingRequestOutput]:
...@@ -93,7 +89,7 @@ def parse_score_data( ...@@ -93,7 +89,7 @@ def parse_score_data(
data_1: str | ScoreContentPartParam, data_1: str | ScoreContentPartParam,
data_2: str | ScoreContentPartParam, data_2: str | ScoreContentPartParam,
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
) -> tuple[str, str, MultiModalDataDict | None]: ) -> tuple[str, str, MultiModalDataDict | None]:
mm_tracker = MultiModalItemTracker(model_config, tokenizer) mm_tracker = MultiModalItemTracker(model_config, tokenizer)
...@@ -118,12 +114,14 @@ def _parse_score_content( ...@@ -118,12 +114,14 @@ def _parse_score_content(
mm_tracker: BaseMultiModalItemTracker, mm_tracker: BaseMultiModalItemTracker,
) -> _ContentPart | None: ) -> _ContentPart | None:
if isinstance(data, str): if isinstance(data, str):
data = ChatCompletionContentPartTextParam(type="text", text=data) part = ChatCompletionContentPartTextParam(type="text", text=data)
else:
part = data
mm_parser = mm_tracker.create_parser() mm_parser = mm_tracker.create_parser()
parse_res = _parse_chat_message_content_part( parse_res = _parse_chat_message_content_part(
data, part,
mm_parser, mm_parser,
wrap_dicts=False, wrap_dicts=False,
interleave_strings=False, interleave_strings=False,
...@@ -181,7 +179,7 @@ def post_process_tokens( ...@@ -181,7 +179,7 @@ def post_process_tokens(
def get_score_prompt( def get_score_prompt(
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
tokenization_kwargs: dict[str, Any], tokenization_kwargs: dict[str, Any],
data_1: str | ScoreContentPartParam, data_1: str | ScoreContentPartParam,
data_2: str | ScoreContentPartParam, data_2: str | ScoreContentPartParam,
......
...@@ -30,7 +30,7 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -30,7 +30,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_models import LoRAModulePath from vllm.entrypoints.openai.serving_models import LoRAModulePath
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizers import MistralTokenizer from vllm.tokenizers import MistralTokenizer
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -17,7 +17,7 @@ from vllm.multimodal.inputs import ( ...@@ -17,7 +17,7 @@ from vllm.multimodal.inputs import (
MultiModalUUIDDict, MultiModalUUIDDict,
) )
from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.utils.jsontree import json_iter_leaves from vllm.utils.jsontree import json_iter_leaves
from vllm.v1.metrics.stats import MultiModalCacheStats from vllm.v1.metrics.stats import MultiModalCacheStats
...@@ -46,7 +46,7 @@ class InputPreprocessor: ...@@ -46,7 +46,7 @@ class InputPreprocessor:
def __init__( def __init__(
self, self,
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: AnyTokenizer | None, tokenizer: TokenizerLike | None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: BaseMultiModalProcessorCache | None = None, mm_processor_cache: BaseMultiModalProcessorCache | None = None,
) -> None: ) -> None:
...@@ -59,7 +59,7 @@ class InputPreprocessor: ...@@ -59,7 +59,7 @@ class InputPreprocessor:
self.mm_cache_stats = MultiModalCacheStats() if mm_processor_cache else None self.mm_cache_stats = MultiModalCacheStats() if mm_processor_cache else None
def get_tokenizer(self) -> AnyTokenizer: def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None: if self.tokenizer is None:
raise ValueError( raise ValueError(
"You cannot pass text prompts when `skip_tokenizer_init` is True" "You cannot pass text prompts when `skip_tokenizer_init` is True"
...@@ -228,11 +228,11 @@ class InputPreprocessor: ...@@ -228,11 +228,11 @@ class InputPreprocessor:
return tokenizer.encode(prompt, **tokenization_kwargs) return tokenizer.encode(prompt, **tokenization_kwargs)
def _get_mm_tokenizer(self) -> AnyTokenizer: def _get_mm_tokenizer(self) -> TokenizerLike:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer # PrithviGeoSpatialMAE needs to be initialized without a tokenizer
# while using also multi-modal input # while using also multi-modal input
if not self.tokenizer: if not self.tokenizer:
return cast(AnyTokenizer, object()) # Dummy return cast(TokenizerLike, object()) # Dummy
tokenizer = self.get_tokenizer() tokenizer = self.get_tokenizer()
return tokenizer return tokenizer
......
...@@ -5,7 +5,7 @@ from typing import TypeAlias ...@@ -5,7 +5,7 @@ from typing import TypeAlias
import torch import torch
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
LogitsProcessor: TypeAlias = ( LogitsProcessor: TypeAlias = (
Callable[[list[int], torch.Tensor], torch.Tensor] Callable[[list[int], torch.Tensor], torch.Tensor]
...@@ -19,7 +19,7 @@ to sample from.""" ...@@ -19,7 +19,7 @@ to sample from."""
def get_bad_words_logits_processors( def get_bad_words_logits_processors(
bad_words: list[str], tokenizer: AnyTokenizer bad_words: list[str], tokenizer: TokenizerLike
) -> list[LogitsProcessor]: ) -> list[LogitsProcessor]:
bad_words_ids: list[list[int]] = list() bad_words_ids: list[list[int]] = list()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment