Unverified commit 34a98427, authored by Cyrus Leung and committed by GitHub
Browse files

[Misc] Refactor tokenizer interface (#29693)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent f223ed41
...@@ -28,7 +28,7 @@ from vllm.multimodal.processing import ( ...@@ -28,7 +28,7 @@ from vllm.multimodal.processing import (
PromptUpdate, PromptUpdate,
PromptUpdateDetails, PromptUpdateDetails,
) )
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from .intern_vit import InternVisionModel from .intern_vit import InternVisionModel
from .internvl import ( from .internvl import (
...@@ -241,7 +241,7 @@ class H2OVLProcessor(BaseInternVLProcessor): ...@@ -241,7 +241,7 @@ class H2OVLProcessor(BaseInternVLProcessor):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
*, *,
min_dynamic_patch: int | None = None, min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None, max_dynamic_patch: int | None = None,
......
...@@ -50,7 +50,7 @@ from vllm.multimodal.processing import ( ...@@ -50,7 +50,7 @@ from vllm.multimodal.processing import (
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.utils.torch_utils import set_default_torch_num_threads from vllm.utils.torch_utils import set_default_torch_num_threads
...@@ -347,7 +347,7 @@ class BaseInternVLProcessor(ABC): ...@@ -347,7 +347,7 @@ class BaseInternVLProcessor(ABC):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
*, *,
min_dynamic_patch: int | None = None, min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None, max_dynamic_patch: int | None = None,
...@@ -561,7 +561,7 @@ class InternVLProcessor(BaseInternVLProcessor): ...@@ -561,7 +561,7 @@ class InternVLProcessor(BaseInternVLProcessor):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
*, *,
min_dynamic_patch: int | None = None, min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None, max_dynamic_patch: int | None = None,
......
...@@ -73,9 +73,9 @@ from vllm.multimodal.processing import ( ...@@ -73,9 +73,9 @@ from vllm.multimodal.processing import (
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.configs.radio import RadioConfig from vllm.transformers_utils.configs.radio import RadioConfig
from vllm.transformers_utils.tokenizer import ( from vllm.transformers_utils.tokenizer import (
AnyTokenizer,
cached_tokenizer_from_config, cached_tokenizer_from_config,
encode_tokens, encode_tokens,
) )
...@@ -284,7 +284,7 @@ class BaseNanoNemotronVLProcessor(ABC): ...@@ -284,7 +284,7 @@ class BaseNanoNemotronVLProcessor(ABC):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
*args, *args,
max_num_tiles: int | None = None, max_num_tiles: int | None = None,
**kwargs, **kwargs,
...@@ -434,7 +434,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -434,7 +434,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
*, *,
max_num_tiles: int | None = None, max_num_tiles: int | None = None,
min_dynamic_patch: int | None = None, min_dynamic_patch: int | None = None,
...@@ -645,7 +645,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -645,7 +645,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
tokens_per_frame: list[int], tokens_per_frame: list[int],
frames_indices: list[int], frames_indices: list[int],
frame_duration_ms: int, frame_duration_ms: int,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
img_start_token_ids: list[int], img_start_token_ids: list[int],
img_end_token_ids: list[int], img_end_token_ids: list[int],
img_context_token_ids: list[int], img_context_token_ids: list[int],
...@@ -670,7 +670,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): ...@@ -670,7 +670,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
tokens_per_frame (list[int]): number of tokens per frame tokens_per_frame (list[int]): number of tokens per frame
frames_indices (list[int]): frame indices frames_indices (list[int]): frame indices
frame_duration_ms (int): duration of each frame in milliseconds frame_duration_ms (int): duration of each frame in milliseconds
tokenizer (AnyTokenizer): tokenizer to use for tokenizing frame separators tokenizer (TokenizerLike): tokenizer to use for tokenizing frame separators
img_start_token_ids (list[int]): pre-tokenized IMG_START tokens img_start_token_ids (list[int]): pre-tokenized IMG_START tokens
img_end_token_ids (list[int]): pre-tokenized IMG_END tokens img_end_token_ids (list[int]): pre-tokenized IMG_END tokens
img_context_token_ids (list[int]): pre-tokenized IMG_CONTEXT tokens img_context_token_ids (list[int]): pre-tokenized IMG_CONTEXT tokens
......
...@@ -34,8 +34,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -34,8 +34,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import convert_image_mode from vllm.multimodal.image import convert_image_mode
from vllm.multimodal.processing import PromptUpdateDetails from vllm.multimodal.processing import PromptUpdateDetails
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.processor import cached_image_processor_from_config from vllm.transformers_utils.processor import cached_image_processor_from_config
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .interfaces import ( from .interfaces import (
MultiModalEmbeddings, MultiModalEmbeddings,
...@@ -203,7 +203,7 @@ class NemotronVLProcessor(InternVLProcessor): ...@@ -203,7 +203,7 @@ class NemotronVLProcessor(InternVLProcessor):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
image_processor: BaseImageProcessorFast, image_processor: BaseImageProcessorFast,
*, *,
min_dynamic_patch: int | None = None, min_dynamic_patch: int | None = None,
......
...@@ -31,7 +31,7 @@ from vllm.multimodal.processing import ( ...@@ -31,7 +31,7 @@ from vllm.multimodal.processing import (
PromptReplacement, PromptReplacement,
PromptUpdate, PromptUpdate,
) )
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from .qwen2_5_vl import ( from .qwen2_5_vl import (
Qwen2_5_VisionTransformer as OpenCUAVisionTransformer, Qwen2_5_VisionTransformer as OpenCUAVisionTransformer,
...@@ -79,7 +79,7 @@ class OpenCUAProcessor(Qwen2VLProcessor): ...@@ -79,7 +79,7 @@ class OpenCUAProcessor(Qwen2VLProcessor):
def __init__( def __init__(
self, self,
vision_config: dict, vision_config: dict,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
**kwargs, **kwargs,
): ):
image_processor = Qwen2VLImageProcessor(**vision_config) image_processor = Qwen2VLImageProcessor(**vision_config)
......
...@@ -59,10 +59,8 @@ from vllm.multimodal.processing import ( ...@@ -59,10 +59,8 @@ from vllm.multimodal.processing import (
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import ( from vllm.tokenizers import MistralTokenizer
MistralTokenizer, from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
cached_tokenizer_from_config,
)
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
......
...@@ -91,7 +91,7 @@ from vllm.multimodal.processing import ( ...@@ -91,7 +91,7 @@ from vllm.multimodal.processing import (
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import ( from .interfaces import (
...@@ -1533,7 +1533,7 @@ class Tarsier2Processor(Qwen2VLProcessor): ...@@ -1533,7 +1533,7 @@ class Tarsier2Processor(Qwen2VLProcessor):
def __init__( def __init__(
self, self,
vision_config: dict, vision_config: dict,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
**kwargs, **kwargs,
): ):
self.image_processor = Tarsier2ImageProcessor(**vision_config) self.image_processor = Tarsier2ImageProcessor(**vision_config)
......
...@@ -47,7 +47,7 @@ from vllm.multimodal.processing import ( ...@@ -47,7 +47,7 @@ from vllm.multimodal.processing import (
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
...@@ -282,7 +282,7 @@ class SkyworkR1VProcessor: ...@@ -282,7 +282,7 @@ class SkyworkR1VProcessor:
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
*, *,
min_dynamic_patch: int | None = None, min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None, max_dynamic_patch: int | None = None,
......
...@@ -43,8 +43,8 @@ from vllm.multimodal.processing import ( ...@@ -43,8 +43,8 @@ from vllm.multimodal.processing import (
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.configs import Step3VisionEncoderConfig from vllm.transformers_utils.configs import Step3VisionEncoderConfig
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
...@@ -321,7 +321,7 @@ class Step3VLProcessor: ...@@ -321,7 +321,7 @@ class Step3VLProcessor:
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
) -> None: ) -> None:
super().__init__() super().__init__()
......
...@@ -51,10 +51,8 @@ from vllm.multimodal.processing import ( ...@@ -51,10 +51,8 @@ from vllm.multimodal.processing import (
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import ( from vllm.tokenizers import MistralTokenizer
MistralTokenizer, from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
cached_tokenizer_from_config,
)
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
from .utils import init_vllm_registered_model, maybe_prefix from .utils import init_vllm_registered_model, maybe_prefix
......
...@@ -23,8 +23,9 @@ import torch ...@@ -23,8 +23,9 @@ import torch
from typing_extensions import TypeVar, assert_never from typing_extensions import TypeVar, assert_never
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.processor import cached_processor_from_config from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.tokenizer import AnyTokenizer, decode_tokens, encode_tokens from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens
from vllm.utils.collection_utils import flatten_2d_lists, full_groupby from vllm.utils.collection_utils import flatten_2d_lists, full_groupby
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
from vllm.utils.jsontree import JSONTree, json_map_leaves from vllm.utils.jsontree import JSONTree, json_map_leaves
...@@ -76,7 +77,7 @@ PromptSeq: TypeAlias = str | list[int] ...@@ -76,7 +77,7 @@ PromptSeq: TypeAlias = str | list[int]
@lru_cache(maxsize=2048) @lru_cache(maxsize=2048)
def _cached_encode( def _cached_encode(
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
text: str, text: str,
*, *,
add_special_tokens: bool | None = None, add_special_tokens: bool | None = None,
...@@ -86,7 +87,7 @@ def _cached_encode( ...@@ -86,7 +87,7 @@ def _cached_encode(
@lru_cache(maxsize=2048) @lru_cache(maxsize=2048)
def _cached_decode( def _cached_decode(
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
token_ids: tuple[int, ...], token_ids: tuple[int, ...],
*, *,
skip_special_tokens: bool | None = None, skip_special_tokens: bool | None = None,
...@@ -96,14 +97,14 @@ def _cached_decode( ...@@ -96,14 +97,14 @@ def _cached_decode(
) )
def _seq2text(tokenizer: AnyTokenizer, seq: PromptSeq) -> str: def _seq2text(tokenizer: TokenizerLike, seq: PromptSeq) -> str:
if isinstance(seq, str): if isinstance(seq, str):
return seq return seq
return _cached_decode(tokenizer, tuple(seq)) return _cached_decode(tokenizer, tuple(seq))
def _seq2tokens(tokenizer: AnyTokenizer, seq: PromptSeq) -> list[int]: def _seq2tokens(tokenizer: TokenizerLike, seq: PromptSeq) -> list[int]:
if isinstance(seq, str): if isinstance(seq, str):
return _cached_encode(tokenizer, seq, add_special_tokens=False) return _cached_encode(tokenizer, seq, add_special_tokens=False)
...@@ -113,7 +114,7 @@ def _seq2tokens(tokenizer: AnyTokenizer, seq: PromptSeq) -> list[int]: ...@@ -113,7 +114,7 @@ def _seq2tokens(tokenizer: AnyTokenizer, seq: PromptSeq) -> list[int]:
class _GetMatchIndex(Protocol): class _GetMatchIndex(Protocol):
def __call__( def __call__(
self, self,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
prompt: PromptSeq, prompt: PromptSeq,
start_idx: int = 0, start_idx: int = 0,
) -> int | None: ... ) -> int | None: ...
...@@ -143,7 +144,7 @@ class PromptIndexTargets: ...@@ -143,7 +144,7 @@ class PromptIndexTargets:
""" """
def get_match_index( def get_match_index(
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
prompt: PromptSeq, prompt: PromptSeq,
start_idx: int = 0, start_idx: int = 0,
) -> int | None: ) -> int | None:
...@@ -199,7 +200,7 @@ class PromptUpdateDetails(Generic[_S]): ...@@ -199,7 +200,7 @@ class PromptUpdateDetails(Generic[_S]):
full: _S full: _S
"""The full content.""" """The full content."""
is_embed: Callable[[AnyTokenizer, PromptSeq], torch.Tensor] | None = None is_embed: Callable[[TokenizerLike, PromptSeq], torch.Tensor] | None = None
""" """
Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full], Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full],
return a boolean mask of shape `(len(full),)` indicating which positions return a boolean mask of shape `(len(full),)` indicating which positions
...@@ -220,7 +221,7 @@ class PromptUpdateDetails(Generic[_S]): ...@@ -220,7 +221,7 @@ class PromptUpdateDetails(Generic[_S]):
seq: _S, seq: _S,
embed_text: str, embed_text: str,
) -> "PromptUpdateDetails[_S]": ) -> "PromptUpdateDetails[_S]":
def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor: def is_embed(tokenizer: TokenizerLike, full: PromptSeq) -> torch.Tensor:
embed_token_ids = encode_tokens(tokenizer, embed_text) embed_token_ids = encode_tokens(tokenizer, embed_text)
token_ids = _seq2tokens(tokenizer, full) token_ids = _seq2tokens(tokenizer, full)
...@@ -236,7 +237,7 @@ class PromptUpdateDetails(Generic[_S]): ...@@ -236,7 +237,7 @@ class PromptUpdateDetails(Generic[_S]):
seq: _S, seq: _S,
embed_token_id: int, embed_token_id: int,
) -> "PromptUpdateDetails[_S]": ) -> "PromptUpdateDetails[_S]":
def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor: def is_embed(tokenizer: TokenizerLike, full: PromptSeq) -> torch.Tensor:
token_ids = _seq2tokens(tokenizer, full) token_ids = _seq2tokens(tokenizer, full)
return torch.tensor(token_ids) == embed_token_id return torch.tensor(token_ids) == embed_token_id
...@@ -522,7 +523,7 @@ class ResolvedPromptUpdate: ...@@ -522,7 +523,7 @@ class ResolvedPromptUpdate:
def iter_token_matches( def iter_token_matches(
self, self,
prompt: list[int], prompt: list[int],
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
*, *,
start_idx: int = 0, start_idx: int = 0,
) -> Generator[PromptTargetMatch]: ) -> Generator[PromptTargetMatch]:
...@@ -544,7 +545,7 @@ class ResolvedPromptUpdate: ...@@ -544,7 +545,7 @@ class ResolvedPromptUpdate:
def iter_text_matches( def iter_text_matches(
self, self,
prompt: str, prompt: str,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
*, *,
start_idx: int = 0, start_idx: int = 0,
) -> Generator[PromptTargetMatch]: ) -> Generator[PromptTargetMatch]:
...@@ -566,7 +567,7 @@ class ResolvedPromptUpdate: ...@@ -566,7 +567,7 @@ class ResolvedPromptUpdate:
def iter_matches( def iter_matches(
self, self,
prompt: list[int] | str, prompt: list[int] | str,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
*, *,
start_idx: int = 0, start_idx: int = 0,
) -> Generator[PromptTargetMatch]: ) -> Generator[PromptTargetMatch]:
...@@ -675,7 +676,7 @@ _MatchToApply = tuple[tuple[str, int], tuple[PromptTargetMatch, int]] ...@@ -675,7 +676,7 @@ _MatchToApply = tuple[tuple[str, int], tuple[PromptTargetMatch, int]]
def _find_matches( def _find_matches(
prompt: _S, prompt: _S,
mm_prompt_updates: "MultiModalPromptUpdates", mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
*, *,
prev_end_idx: int = 0, prev_end_idx: int = 0,
current_result: "MultiModalPromptUpdatesApplyResult", current_result: "MultiModalPromptUpdatesApplyResult",
...@@ -740,7 +741,7 @@ def _all_items_found( ...@@ -740,7 +741,7 @@ def _all_items_found(
def _apply_matches( def _apply_matches(
prompt: _S, prompt: _S,
mm_prompt_updates: "MultiModalPromptUpdates", mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]: ) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]:
mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()} mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()}
...@@ -806,7 +807,7 @@ def _apply_matches( ...@@ -806,7 +807,7 @@ def _apply_matches(
def apply_token_matches( def apply_token_matches(
prompt: list[int], prompt: list[int],
mm_prompt_updates: "MultiModalPromptUpdates", mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
) -> tuple[list[int], "MultiModalPromptUpdatesApplyResult"]: ) -> tuple[list[int], "MultiModalPromptUpdatesApplyResult"]:
""" """
Apply the updates in `mm_prompt_updates` to `prompt`. Apply the updates in `mm_prompt_updates` to `prompt`.
...@@ -823,7 +824,7 @@ def apply_token_matches( ...@@ -823,7 +824,7 @@ def apply_token_matches(
def apply_text_matches( def apply_text_matches(
prompt: str, prompt: str,
mm_prompt_updates: "MultiModalPromptUpdates", mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
) -> tuple[str, "MultiModalPromptUpdatesApplyResult"]: ) -> tuple[str, "MultiModalPromptUpdatesApplyResult"]:
""" """
Apply the updates in `mm_prompt_updates` to `prompt`. Apply the updates in `mm_prompt_updates` to `prompt`.
...@@ -840,7 +841,7 @@ def apply_text_matches( ...@@ -840,7 +841,7 @@ def apply_text_matches(
def _iter_placeholders( def _iter_placeholders(
prompt: list[int], prompt: list[int],
mm_prompt_updates: "MultiModalPromptUpdates", mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
) -> Iterable[PlaceholderFeaturesInfo]: ) -> Iterable[PlaceholderFeaturesInfo]:
""" """
Yield each set of placeholder tokens found in `prompt`. Yield each set of placeholder tokens found in `prompt`.
...@@ -909,7 +910,7 @@ def _iter_placeholders( ...@@ -909,7 +910,7 @@ def _iter_placeholders(
def find_mm_placeholders( def find_mm_placeholders(
prompt: list[int], prompt: list[int],
mm_prompt_updates: "MultiModalPromptUpdates", mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer, tokenizer: TokenizerLike,
) -> Mapping[str, list[PlaceholderFeaturesInfo]]: ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
it = _iter_placeholders(prompt, mm_prompt_updates, tokenizer) it = _iter_placeholders(prompt, mm_prompt_updates, tokenizer)
return dict(full_groupby_modality(it)) return dict(full_groupby_modality(it))
...@@ -930,7 +931,7 @@ class InputProcessingContext: ...@@ -930,7 +931,7 @@ class InputProcessingContext:
model_config: ModelConfig model_config: ModelConfig
"""The configuration of the model.""" """The configuration of the model."""
tokenizer: AnyTokenizer tokenizer: TokenizerLike
"""The tokenizer used to tokenize the inputs.""" """The tokenizer used to tokenize the inputs."""
@overload @overload
...@@ -1146,7 +1147,7 @@ class BaseProcessingInfo: ...@@ -1146,7 +1147,7 @@ class BaseProcessingInfo:
def model_id(self) -> str: def model_id(self) -> str:
return self.ctx.model_config.model return self.ctx.model_config.model
def get_tokenizer(self) -> AnyTokenizer: def get_tokenizer(self) -> TokenizerLike:
return self.ctx.tokenizer return self.ctx.tokenizer
def get_hf_config(self) -> PretrainedConfig: def get_hf_config(self) -> PretrainedConfig:
......
...@@ -6,7 +6,8 @@ from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast ...@@ -6,7 +6,8 @@ from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, cached_tokenizer_from_config from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from .cache import BaseMultiModalProcessorCache from .cache import BaseMultiModalProcessorCache
from .processing import ( from .processing import (
...@@ -231,17 +232,20 @@ class MultiModalRegistry: ...@@ -231,17 +232,20 @@ class MultiModalRegistry:
def _create_processing_ctx( def _create_processing_ctx(
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
tokenizer: AnyTokenizer | None = None, tokenizer: TokenizerLike | None = None,
) -> InputProcessingContext: ) -> InputProcessingContext:
if tokenizer is None and not model_config.skip_tokenizer_init: if model_config.skip_tokenizer_init:
tokenizer = cast(TokenizerLike, object())
elif tokenizer is None:
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(model_config)
return InputProcessingContext(model_config, tokenizer) return InputProcessingContext(model_config, tokenizer)
def _create_processing_info( def _create_processing_info(
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
*, *,
tokenizer: AnyTokenizer | None = None, tokenizer: TokenizerLike | None = None,
) -> BaseProcessingInfo: ) -> BaseProcessingInfo:
model_cls = self._get_model_cls(model_config) model_cls = self._get_model_cls(model_config)
factories = model_cls._processor_factory factories = model_cls._processor_factory
...@@ -252,7 +256,7 @@ class MultiModalRegistry: ...@@ -252,7 +256,7 @@ class MultiModalRegistry:
self, self,
model_config: "ModelConfig", model_config: "ModelConfig",
*, *,
tokenizer: AnyTokenizer | None = None, tokenizer: TokenizerLike | None = None,
cache: BaseMultiModalProcessorCache | None = None, cache: BaseMultiModalProcessorCache | None = None,
) -> BaseMultiModalProcessor[BaseProcessingInfo]: ) -> BaseMultiModalProcessor[BaseProcessingInfo]:
""" """
......
...@@ -19,12 +19,12 @@ if TYPE_CHECKING: ...@@ -19,12 +19,12 @@ if TYPE_CHECKING:
DeltaMessage, DeltaMessage,
ResponsesRequest, ResponsesRequest,
) )
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
else: else:
ChatCompletionRequest = Any ChatCompletionRequest = Any
DeltaMessage = Any DeltaMessage = Any
ResponsesRequest = Any ResponsesRequest = Any
AnyTokenizer = Any TokenizerLike = Any
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -37,7 +37,7 @@ class ReasoningParser: ...@@ -37,7 +37,7 @@ class ReasoningParser:
It is used to extract reasoning content from the model output. It is used to extract reasoning content from the model output.
""" """
def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs): def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
self.model_tokenizer = tokenizer self.model_tokenizer = tokenizer
@cached_property @cached_property
......
...@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any ...@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
from vllm.entrypoints.openai.protocol import DeltaMessage from vllm.entrypoints.openai.protocol import DeltaMessage
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
...@@ -43,7 +43,7 @@ class BaseThinkingReasoningParser(ReasoningParser): ...@@ -43,7 +43,7 @@ class BaseThinkingReasoningParser(ReasoningParser):
"""The token that ends reasoning content.""" """The token that ends reasoning content."""
raise NotImplementedError raise NotImplementedError
def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs): def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs) super().__init__(tokenizer, *args, **kwargs)
if not self.model_tokenizer: if not self.model_tokenizer:
......
...@@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -37,7 +37,7 @@ class MiniMaxM2AppendThinkReasoningParser(ReasoningParser): ...@@ -37,7 +37,7 @@ class MiniMaxM2AppendThinkReasoningParser(ReasoningParser):
Reasoning parser for MiniMax M2 model. Reasoning parser for MiniMax M2 model.
""" """
def __init__(self, tokenizer: AnyTokenizer, *args, **kwargs): def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs) super().__init__(tokenizer, *args, **kwargs)
self.end_token_id = self.vocab.get("</think>") self.end_token_id = self.vocab.get("</think>")
......
...@@ -6,7 +6,7 @@ from functools import cached_property ...@@ -6,7 +6,7 @@ from functools import cached_property
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser from vllm.reasoning import ReasoningParser
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.tokenizers import MistralTokenizer
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING ...@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING
import regex as re import regex as re
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
...@@ -220,7 +220,7 @@ class Olmo3ReasoningParser(ReasoningParser): ...@@ -220,7 +220,7 @@ class Olmo3ReasoningParser(ReasoningParser):
token is missing from generation. token is missing from generation.
""" """
def __init__(self, tokenizer: "AnyTokenizer", *args, **kwargs): def __init__(self, tokenizer: "TokenizerLike", *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs) super().__init__(tokenizer, *args, **kwargs)
self.think_start = r"<think>" self.think_start = r"<think>"
......
...@@ -13,7 +13,7 @@ from pydantic.dataclasses import dataclass ...@@ -13,7 +13,7 @@ from pydantic.dataclasses import dataclass
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logits_process import LogitsProcessor from vllm.logits_process import LogitsProcessor
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.tokenizers import TokenizerLike
from vllm.v1.serial_utils import PydanticMsgspecMixin from vllm.v1.serial_utils import PydanticMsgspecMixin
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -477,7 +477,7 @@ class SamplingParams( ...@@ -477,7 +477,7 @@ class SamplingParams(
eos_ids.update(self.stop_token_ids) eos_ids.update(self.stop_token_ids)
self.stop_token_ids = list(eos_ids) self.stop_token_ids = list(eos_ids)
def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None: def update_from_tokenizer(self, tokenizer: TokenizerLike) -> None:
if not self.bad_words: if not self.bad_words:
return return
self._bad_words_token_ids = [] self._bad_words_token_ids = []
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .mistral import MistralTokenizer
from .protocol import TokenizerLike
from .registry import TokenizerRegistry
# Names re-exported by this package: each entry is imported from a sibling
# module above (`.protocol`, `.mistral`, `.registry`), so callers can import
# them from the package directly.
__all__ = ["TokenizerLike", "MistralTokenizer", "TokenizerRegistry"]
...@@ -4,7 +4,8 @@ ...@@ -4,7 +4,8 @@
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, cast
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_base import TokenizerBase
from .protocol import TokenizerLike
if TYPE_CHECKING: if TYPE_CHECKING:
from mistral_common.protocol.instruct.request import ( from mistral_common.protocol.instruct.request import (
...@@ -163,7 +164,7 @@ def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int: ...@@ -163,7 +164,7 @@ def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int:
return tokenizer.unk_id return tokenizer.unk_id
class MistralTokenizer(TokenizerBase): class MistralTokenizer(TokenizerLike):
def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None: def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None:
from mistral_common.protocol.instruct.validator import ValidationMode from mistral_common.protocol.instruct.validator import ValidationMode
from mistral_common.tokens.tokenizers.sentencepiece import ( from mistral_common.tokens.tokenizers.sentencepiece import (
...@@ -270,14 +271,6 @@ class MistralTokenizer(TokenizerBase): ...@@ -270,14 +271,6 @@ class MistralTokenizer(TokenizerBase):
def eos_token_id(self) -> int: def eos_token_id(self) -> int:
return self.tokenizer.eos_id return self.tokenizer.eos_id
@property
def sep_token(self) -> str:
raise NotImplementedError()
@property
def pad_token(self) -> str:
return self.transformers_tokenizer.pad_token
@property @property
def is_fast(self) -> bool: def is_fast(self) -> bool:
return True return True
...@@ -292,11 +285,14 @@ class MistralTokenizer(TokenizerBase): ...@@ -292,11 +285,14 @@ class MistralTokenizer(TokenizerBase):
@property @property
def truncation_side(self) -> str: def truncation_side(self) -> str:
raise NotImplementedError() return self.transformers_tokenizer.truncation_side
def _is_special_token_id(self, token_id: int) -> bool: def _is_special_token_id(self, token_id: int) -> bool:
return token_id in self._special_token_ids_set return token_id in self._special_token_ids_set
def __hash__(self) -> int:
return hash(id(self))
def __len__(self) -> int: def __len__(self) -> int:
return self.vocab_size return self.vocab_size
...@@ -341,17 +337,6 @@ class MistralTokenizer(TokenizerBase): ...@@ -341,17 +337,6 @@ class MistralTokenizer(TokenizerBase):
# Mistral tokenizers have no added vocabulary # Mistral tokenizers have no added vocabulary
return {} return {}
def encode_one(
self,
text: str,
truncation: bool = False,
max_length: int | None = None,
) -> list[int]:
# Mistral Tokenizers should not add special tokens
return self.transformers_tokenizer.encode(
text, add_special_tokens=False, truncation=truncation, max_length=max_length
)
def encode( def encode(
self, self,
text: str, text: str,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment