Unverified Commit 34a98427 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Refactor tokenizer interface (#29693)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent f223ed41
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Protocol
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.transformers_utils.tokenizer_base import TokenizerBase, TokenizerRegistry
from typing_extensions import Self
if TYPE_CHECKING:
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
class TestTokenizer(TokenizerBase):
class TokenizerLike(Protocol):
@classmethod
def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer":
return TestTokenizer()
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
/,
*,
revision: str | None = None,
) -> Self:
raise NotImplementedError
@property
def all_special_tokens(self) -> list[str]:
raise NotImplementedError()
raise NotImplementedError
@property
def all_special_ids(self) -> list[int]:
raise NotImplementedError()
raise NotImplementedError
@property
def bos_token_id(self) -> int:
return 0
raise NotImplementedError
@property
def eos_token_id(self) -> int:
return 1
@property
def sep_token(self) -> str:
raise NotImplementedError()
@property
def pad_token(self) -> str:
raise NotImplementedError()
raise NotImplementedError
@property
def is_fast(self) -> bool:
raise NotImplementedError()
raise NotImplementedError
@property
def vocab_size(self) -> int:
raise NotImplementedError()
raise NotImplementedError
@property
def max_token_id(self) -> int:
raise NotImplementedError()
raise NotImplementedError
@property
def truncation_side(self) -> str:
raise NotImplementedError()
raise NotImplementedError
def __hash__(self) -> int:
return hash(id(self))
def __len__(self) -> int:
return self.vocab_size
def __call__(
self,
......@@ -63,24 +66,22 @@ class TestTokenizer(TokenizerBase):
truncation: bool = False,
max_length: int | None = None,
):
raise NotImplementedError()
raise NotImplementedError
def get_vocab(self) -> dict[str, int]:
raise NotImplementedError()
raise NotImplementedError
def get_added_vocab(self) -> dict[str, int]:
raise NotImplementedError()
raise NotImplementedError
def encode_one(
def encode(
self,
text: str,
truncation: bool = False,
truncation: bool | None = None,
max_length: int | None = None,
add_special_tokens: bool | None = None,
) -> list[int]:
raise NotImplementedError()
def encode(self, text: str, add_special_tokens: bool | None = None) -> list[int]:
raise NotImplementedError()
raise NotImplementedError
def apply_chat_template(
self,
......@@ -88,33 +89,17 @@ class TestTokenizer(TokenizerBase):
tools: list[dict[str, Any]] | None = None,
**kwargs,
) -> list[int]:
raise NotImplementedError()
raise NotImplementedError
def convert_tokens_to_string(self, tokens: list[str]) -> str:
raise NotImplementedError()
raise NotImplementedError
def decode(self, ids: list[int] | int, skip_special_tokens: bool = True) -> str:
raise NotImplementedError()
raise NotImplementedError
def convert_ids_to_tokens(
self,
ids: list[int],
skip_special_tokens: bool = True,
) -> list[str]:
raise NotImplementedError()
def test_customized_tokenizer():
TokenizerRegistry.register(
"test_tokenizer", "tests.tokenization.test_tokenizer_registry", "TestTokenizer"
)
tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer")
assert isinstance(tokenizer, TestTokenizer)
assert tokenizer.bos_token_id == 0
assert tokenizer.eos_token_id == 1
tokenizer = get_tokenizer("test_tokenizer", tokenizer_mode="custom")
assert isinstance(tokenizer, TestTokenizer)
assert tokenizer.bos_token_id == 0
assert tokenizer.eos_token_id == 1
raise NotImplementedError
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from .protocol import TokenizerLike
class TokenizerRegistry:
# Tokenizer name -> (tokenizer module, tokenizer class)
REGISTRY: dict[str, tuple[str, str]] = {}
@staticmethod
def register(name: str, module: str, class_name: str) -> None:
TokenizerRegistry.REGISTRY[name] = (module, class_name)
@staticmethod
def get_tokenizer(
tokenizer_name: str,
*args,
**kwargs,
) -> "TokenizerLike":
tokenizer_cls = TokenizerRegistry.REGISTRY.get(tokenizer_name)
if tokenizer_cls is None:
raise ValueError(f"Tokenizer {tokenizer_name} not found.")
tokenizer_module = importlib.import_module(tokenizer_cls[0])
class_ = getattr(tokenizer_module, tokenizer_cls[1])
return class_.from_pretrained(*args, **kwargs)
......@@ -26,8 +26,9 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
from vllm import envs
from vllm.logger import init_logger
from vllm.transformers_utils.config_parser_base import ConfigParserBase
from vllm.transformers_utils.repo_utils import (
from .config_parser_base import ConfigParserBase
from .repo_utils import (
_get_hf_token,
file_or_path_exists,
get_hf_file_to_dict,
......@@ -35,7 +36,7 @@ from vllm.transformers_utils.repo_utils import (
try_get_local_file,
with_retry,
)
from vllm.transformers_utils.utils import (
from .utils import (
check_gguf_file,
is_gguf,
is_remote_gguf,
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
def _replace_none_with_empty(tokens: list[str | None]):
......@@ -12,7 +12,7 @@ def _replace_none_with_empty(tokens: list[str | None]):
def _convert_tokens_to_string_with_added_encoders(
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
output_tokens: list[str],
skip_special_tokens: bool,
spaces_between_special_tokens: bool,
......@@ -57,7 +57,7 @@ INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5
def convert_prompt_ids_to_tokens(
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
prompt_ids: list[int],
skip_special_tokens: bool = False,
) -> tuple[list[str], int, int]:
......@@ -81,7 +81,7 @@ def convert_prompt_ids_to_tokens(
def convert_ids_list_to_tokens(
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
token_ids: list[int],
) -> list[str]:
"""Detokenize the input ids individually.
......@@ -108,7 +108,7 @@ def convert_ids_list_to_tokens(
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license
def detokenize_incrementally(
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
all_input_ids: list[int],
prev_tokens: list[str] | None,
prefix_offset: int,
......
......@@ -9,7 +9,8 @@ from gguf.constants import Keys, VisionProjectorType
from transformers import Gemma3Config, PretrainedConfig, SiglipVisionConfig
from vllm.logger import init_logger
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
from .repo_utils import list_filtered_repo_files
logger = init_logger(__name__)
......
......@@ -5,41 +5,48 @@ import contextlib
import copy
import importlib.util
import os
import warnings
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeAlias
from typing import TYPE_CHECKING, Any
import huggingface_hub
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from typing_extensions import assert_never
from vllm import envs
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config
from vllm.transformers_utils.gguf_utils import get_gguf_file_path_from_hf
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.transformers_utils.utils import (
check_gguf_file,
is_gguf,
is_remote_gguf,
split_remote_gguf,
)
from vllm.tokenizers import MistralTokenizer, TokenizerLike, TokenizerRegistry
from .config import get_sentence_transformer_tokenizer_config
from .gguf_utils import get_gguf_file_path_from_hf
from .repo_utils import list_filtered_repo_files
from .utils import check_gguf_file, is_gguf, is_remote_gguf, split_remote_gguf
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.transformers_utils.tokenizer_base import TokenizerBase
else:
ModelConfig = Any
TokenizerBase = Any
logger = init_logger(__name__)
AnyTokenizer: TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast | TokenizerBase
def __getattr__(name: str):
if name == "AnyTokenizer":
warnings.warn(
"`vllm.transformers_utils.tokenizer.AnyTokenizer` has been moved to "
"`vllm.tokenizers.TokenizerLike`. "
"The old name will be removed in v0.13.",
DeprecationWarning,
stacklevel=2,
)
return TokenizerLike
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def decode_tokens(
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
token_ids: list[int],
*,
skip_special_tokens: bool | None = None,
......@@ -58,7 +65,7 @@ def decode_tokens(
def encode_tokens(
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
text: str,
*,
truncation: bool | None = None,
......@@ -86,7 +93,7 @@ def encode_tokens(
return tokenizer.encode(text, **kw_args)
def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
def get_cached_tokenizer(tokenizer: TokenizerLike) -> TokenizerLike:
"""
By default, transformers will recompute multiple tokenizer properties
each time they are called, leading to a significant slowdown.
......@@ -144,7 +151,7 @@ def get_tokenizer(
revision: str | None = None,
download_dir: str | None = None,
**kwargs,
) -> AnyTokenizer:
) -> TokenizerLike:
"""Gets a tokenizer for the given model name via HuggingFace or ModelScope."""
if envs.VLLM_USE_MODELSCOPE:
# download model from ModelScope hub,
......@@ -206,15 +213,13 @@ def get_tokenizer(
if len(files_list) > 0:
tokenizer_mode = "mistral"
tokenizer: AnyTokenizer
tokenizer: TokenizerLike
if tokenizer_mode == "mistral":
logger.debug_once(f"Loading MistralTokenizer from {tokenizer_name}")
tokenizer = MistralTokenizer.from_pretrained(
str(tokenizer_name), revision=revision
)
elif tokenizer_mode == "custom":
from vllm.transformers_utils.tokenizer_base import TokenizerRegistry
logger.debug_once(f"Loading CustomTokenizer from {tokenizer_name}")
tokenizer = TokenizerRegistry.get_tokenizer(
str(tokenizer_name),
......@@ -260,12 +265,13 @@ def get_tokenizer(
if isinstance(encoder_config, dict) and encoder_config.get(
"do_lower_case", False
):
assert isinstance(tokenizer, PreTrainedTokenizerBase)
special_tokens_map = {
k: v.lower() for k, v in tokenizer.special_tokens_map.items()
}
tokenizer.add_special_tokens(special_tokens_map)
if not isinstance(tokenizer, PreTrainedTokenizerFast):
if not tokenizer.is_fast:
logger.warning(
"Using a slow tokenizer. This might cause a significant "
"slowdown. Consider using a fast tokenizer instead."
......@@ -279,7 +285,7 @@ cached_get_tokenizer = lru_cache(get_tokenizer)
def cached_tokenizer_from_config(
model_config: ModelConfig,
model_config: "ModelConfig",
**kwargs: Any,
):
return cached_get_tokenizer(
......@@ -291,7 +297,7 @@ def cached_tokenizer_from_config(
)
def init_tokenizer_from_configs(model_config: ModelConfig):
def init_tokenizer_from_configs(model_config: "ModelConfig"):
runner_type = model_config.runner_type
if runner_type == "generate" or runner_type == "draft":
truncation_side = "left"
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
import warnings
if TYPE_CHECKING:
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
def __getattr__(name: str):
if name == "TokenizerBase":
from vllm.tokenizers import TokenizerLike
class TokenizerBase(ABC):
@property
@abstractmethod
def all_special_tokens(self) -> list[str]:
raise NotImplementedError()
warnings.warn(
"`vllm.transformers_utils.tokenizer_base.TokenizerBase` has been "
"moved to `vllm.tokenizers.TokenizerLike`. "
"The old name will be removed in v0.13.",
DeprecationWarning,
stacklevel=2,
)
@property
@abstractmethod
def all_special_ids(self) -> list[int]:
raise NotImplementedError()
return TokenizerLike
if name == "TokenizerRegistry":
from vllm.tokenizers import TokenizerRegistry
@property
@abstractmethod
def bos_token_id(self) -> int:
raise NotImplementedError()
warnings.warn(
"`vllm.transformers_utils.tokenizer_base.TokenizerRegistry` has been "
"moved to `vllm.tokenizers.TokenizerRegistry`. "
"The old name will be removed in v0.13.",
DeprecationWarning,
stacklevel=2,
)
@property
@abstractmethod
def eos_token_id(self) -> int:
raise NotImplementedError()
return TokenizerRegistry
@property
@abstractmethod
def sep_token(self) -> str:
raise NotImplementedError()
@property
@abstractmethod
def pad_token(self) -> str:
raise NotImplementedError()
@property
@abstractmethod
def is_fast(self) -> bool:
raise NotImplementedError()
@property
@abstractmethod
def vocab_size(self) -> int:
raise NotImplementedError()
@property
@abstractmethod
def max_token_id(self) -> int:
raise NotImplementedError()
@property
@abstractmethod
def truncation_side(self) -> str:
raise NotImplementedError()
def __len__(self) -> int:
return self.vocab_size
@abstractmethod
def __call__(
self,
text: str | list[str] | list[int],
text_pair: str | None = None,
add_special_tokens: bool = False,
truncation: bool = False,
max_length: int | None = None,
):
raise NotImplementedError()
@abstractmethod
def get_vocab(self) -> dict[str, int]:
raise NotImplementedError()
@abstractmethod
def get_added_vocab(self) -> dict[str, int]:
raise NotImplementedError()
@abstractmethod
def encode_one(
self,
text: str,
truncation: bool = False,
max_length: int | None = None,
) -> list[int]:
raise NotImplementedError()
@abstractmethod
def encode(
self,
text: str,
truncation: bool | None = None,
max_length: int | None = None,
add_special_tokens: bool | None = None,
) -> list[int]:
raise NotImplementedError()
@abstractmethod
def apply_chat_template(
self,
messages: list["ChatCompletionMessageParam"],
tools: list[dict[str, Any]] | None = None,
**kwargs,
) -> list[int]:
raise NotImplementedError()
@abstractmethod
def convert_tokens_to_string(self, tokens: list[str]) -> str:
raise NotImplementedError()
@abstractmethod
def decode(self, ids: list[int] | int, skip_special_tokens: bool = True) -> str:
raise NotImplementedError()
@abstractmethod
def convert_ids_to_tokens(
self,
ids: list[int],
skip_special_tokens: bool = True,
) -> list[str]:
raise NotImplementedError()
class TokenizerRegistry:
# Tokenizer name -> (tokenizer module, tokenizer class)
REGISTRY: dict[str, tuple[str, str]] = {}
@staticmethod
def register(name: str, module: str, class_name: str) -> None:
TokenizerRegistry.REGISTRY[name] = (module, class_name)
@staticmethod
def get_tokenizer(
tokenizer_name: str,
*args,
**kwargs,
) -> TokenizerBase:
tokenizer_cls = TokenizerRegistry.REGISTRY.get(tokenizer_name)
if tokenizer_cls is None:
raise ValueError(f"Tokenizer {tokenizer_name} not found.")
tokenizer_module = importlib.import_module(tokenizer_cls[0])
class_ = getattr(tokenizer_module, tokenizer_cls[1])
return class_.from_pretrained(*args, **kwargs)
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .mistral import (
MistralTokenizer,
maybe_serialize_tool_calls,
truncate_tool_call_ids,
validate_request_params,
)
__all__ = [
"MistralTokenizer",
"maybe_serialize_tool_calls",
"truncate_tool_call_ids",
"validate_request_params",
]
......@@ -26,9 +26,10 @@ from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike
from vllm.tracing import init_tracer
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
......@@ -120,9 +121,10 @@ class AsyncLLM(EngineClient):
)
# OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
stream_interval = self.vllm_config.scheduler_config.stream_interval
self.output_processor = OutputProcessor(
self.tokenizer, log_stats=self.log_stats, stream_interval=stream_interval
self.tokenizer,
log_stats=self.log_stats,
stream_interval=self.vllm_config.scheduler_config.stream_interval,
)
endpoint = self.observability_config.otlp_traces_endpoint
if endpoint is not None:
......@@ -703,17 +705,17 @@ class AsyncLLM(EngineClient):
raise EngineGenerateError() from e
@property
def tokenizer(self) -> AnyTokenizer | None:
def tokenizer(self) -> TokenizerLike | None:
return self.input_processor.tokenizer
@tokenizer.setter
def tokenizer(self, tokenizer: AnyTokenizer | None) -> None:
def tokenizer(self, tokenizer: TokenizerLike | None) -> None:
self.input_processor.tokenizer = tokenizer
async def get_tokenizer(self) -> AnyTokenizer:
async def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None:
raise ValueError(
"Unable to get tokenizer because skip_tokenizer_init is True"
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
return self.tokenizer
......
......@@ -10,7 +10,7 @@ from transformers import PreTrainedTokenizerFast
from vllm.logger import init_logger
from vllm.transformers_utils.detokenizer_utils import (
AnyTokenizer,
TokenizerLike,
convert_prompt_ids_to_tokens,
detokenize_incrementally,
)
......@@ -45,7 +45,7 @@ class IncrementalDetokenizer:
@classmethod
def from_new_request(
cls,
tokenizer: AnyTokenizer | None,
tokenizer: TokenizerLike | None,
request: EngineCoreRequest,
) -> "IncrementalDetokenizer":
assert request.sampling_params is not None
......@@ -256,7 +256,7 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest):
def __init__(self, tokenizer: TokenizerLike, request: EngineCoreRequest):
super().__init__(request)
self.tokenizer = tokenizer
......
......@@ -19,8 +19,7 @@ from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.metrics.stats import MultiModalCacheStats
......@@ -40,7 +39,7 @@ class InputProcessor:
def __init__(
self,
vllm_config: VllmConfig,
tokenizer: AnyTokenizer | None,
tokenizer: TokenizerLike | None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
) -> None:
self.vllm_config = vllm_config
......@@ -62,11 +61,11 @@ class InputProcessor:
)
@property
def tokenizer(self) -> AnyTokenizer | None:
def tokenizer(self) -> TokenizerLike | None:
return self.input_preprocessor.tokenizer
@tokenizer.setter
def tokenizer(self, tokenizer: AnyTokenizer | None) -> None:
def tokenizer(self, tokenizer: TokenizerLike | None) -> None:
self.input_preprocessor.tokenizer = tokenizer
def _validate_logprobs(
......
......@@ -23,8 +23,9 @@ from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike
from vllm.tracing import init_tracer
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
......@@ -95,9 +96,10 @@ class LLMEngine:
)
# OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
stream_interval = self.vllm_config.scheduler_config.stream_interval
self.output_processor = OutputProcessor(
self.tokenizer, log_stats=self.log_stats, stream_interval=stream_interval
self.tokenizer,
log_stats=self.log_stats,
stream_interval=self.vllm_config.scheduler_config.stream_interval,
)
endpoint = self.observability_config.otlp_traces_endpoint
if endpoint is not None:
......@@ -350,17 +352,17 @@ class LLMEngine:
return get_metrics_snapshot()
@property
def tokenizer(self) -> AnyTokenizer | None:
def tokenizer(self) -> TokenizerLike | None:
return self.input_processor.tokenizer
@tokenizer.setter
def tokenizer(self, tokenizer: AnyTokenizer | None) -> None:
def tokenizer(self, tokenizer: TokenizerLike | None) -> None:
self.input_processor.tokenizer = tokenizer
def get_tokenizer(self) -> AnyTokenizer:
def get_tokenizer(self) -> TokenizerLike:
if self.tokenizer is None:
raise ValueError(
"Unable to get tokenizer because skip_tokenizer_init is True"
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
return self.tokenizer
......
......@@ -13,7 +13,7 @@ from vllm.logprobs import (
create_sample_logprobs,
)
from vllm.transformers_utils.detokenizer_utils import (
AnyTokenizer,
TokenizerLike,
convert_ids_list_to_tokens,
)
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
......@@ -28,7 +28,7 @@ NONES = itertools.repeat(None)
class LogprobsProcessor:
# Tokenizer for this request,
# None if detokenization is disabled.
tokenizer: AnyTokenizer | None
tokenizer: TokenizerLike | None
# Logprobs for this request
logprobs: SampleLogprobs | None
......@@ -40,7 +40,7 @@ class LogprobsProcessor:
@classmethod
def from_new_request(
cls,
tokenizer: AnyTokenizer | None,
tokenizer: TokenizerLike | None,
request: EngineCoreRequest,
) -> "LogprobsProcessor":
sampling_params = request.sampling_params
......
......@@ -15,8 +15,8 @@ from vllm.outputs import (
RequestOutput,
)
from vllm.sampling_params import RequestOutputKind
from vllm.tokenizers import TokenizerLike
from vllm.tracing import SpanAttributes, SpanKind, Tracer, extract_trace_context
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
......@@ -139,7 +139,7 @@ class RequestState:
@classmethod
def from_new_request(
cls,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike | None,
request: EngineCoreRequest,
prompt: str | None,
parent_req: ParentRequest | None,
......@@ -341,7 +341,10 @@ class OutputProcessor:
"""Process EngineCoreOutputs into RequestOutputs."""
def __init__(
self, tokenizer: AnyTokenizer, log_stats: bool, stream_interval: int = 1
self,
tokenizer: TokenizerLike | None,
log_stats: bool,
stream_interval: int = 1,
):
self.log_stats = log_stats
self.tokenizer = tokenizer
......
......@@ -10,10 +10,10 @@ if TYPE_CHECKING:
import torch
from vllm.config import VllmConfig
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
else:
VllmConfig = object
AnyTokenizer = object
TokenizerLike = object
class StructuredOutputOptions(enum.Enum):
......@@ -100,7 +100,7 @@ class StructuredOutputBackend(ABC):
"""Engine-level backend for structured output requests."""
vllm_config: VllmConfig
tokenizer: AnyTokenizer
tokenizer: TokenizerLike
vocab_size: int
@abstractmethod
......
......@@ -10,7 +10,7 @@ import torch
import vllm.envs
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.tokenizers import MistralTokenizer
from vllm.utils.import_utils import LazyLoader
from vllm.v1.structured_output.backend_types import (
StructuredOutputBackend,
......
......@@ -24,7 +24,7 @@ if TYPE_CHECKING:
import transformers.models.gpt2.tokenization_gpt2 as tokenization_gpt2
import xgrammar as xgr
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.tokenizers import TokenizerLike
from vllm.v1.worker.gpu_input_batch import InputBatch
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
......@@ -36,7 +36,7 @@ else:
"transformers.models.gpt2.tokenization_gpt2",
)
AnyTokenizer = object
TokenizerLike = object
SchedulerOutput = object
InputBatch = object
......@@ -195,7 +195,7 @@ re_replacement_seq = re.compile(r"^.{0,6}�+.{0,6}$")
def _reduced_vocabulary(
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike,
eos_token_id: int,
) -> dict[bytes, list[int]]:
"""Create a map from vocabulary tokens to lists of equivalent token ids.
......@@ -222,7 +222,7 @@ def _reduced_vocabulary(
vocabulary: dict[bytes, list[int]] = {}
empty_token_ids: list[int] = []
for token, token_idx in tokenizer.get_vocab().items():
if token in tokenizer.all_special_tokens: # type: ignore
if token in tokenizer.all_special_tokens:
continue
token_str = convert_token_to_string(token)
......@@ -261,7 +261,7 @@ def _reduced_vocabulary(
return vocabulary
def get_outlines_vocabulary(tokenizer: AnyTokenizer) -> oc.Vocabulary:
def get_outlines_vocabulary(tokenizer: TokenizerLike) -> oc.Vocabulary:
"""Get the `Vocabulary` object for a given tokenizer."""
if hasattr(tokenizer, "_outlines_vocabulary"):
return tokenizer._outlines_vocabulary # type: ignore
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment