Unverified Commit 653591d5 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Chore] Move tokenizer initialization methods (#29793)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent e2fbfc95
......@@ -73,12 +73,9 @@ from vllm.multimodal.processing import (
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.transformers_utils.configs.radio import RadioConfig
from vllm.transformers_utils.tokenizer import (
cached_tokenizer_from_config,
encode_tokens,
)
from vllm.transformers_utils.tokenizer import encode_tokens
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .utils import _merge_multimodal_embeddings
......
......@@ -59,8 +59,7 @@ from vllm.multimodal.processing import (
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import MistralTokenizer
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.tokenizers import MistralTokenizer, cached_tokenizer_from_config
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
......
......@@ -51,8 +51,7 @@ from vllm.multimodal.processing import (
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import MistralTokenizer
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.tokenizers import MistralTokenizer, cached_tokenizer_from_config
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
from .utils import init_vllm_registered_model, maybe_prefix
......
......@@ -48,7 +48,7 @@ from vllm.multimodal.processing import (
PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.utils.torch_utils import set_default_torch_dtype
......@@ -850,7 +850,7 @@ class WhisperForConditionalGeneration(
def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: str
) -> SpeechToTextConfig:
processor = cached_get_processor(model_config.model)
processor = cached_processor_from_config(model_config)
return SpeechToTextConfig(
max_audio_clip_s=processor.feature_extractor.chunk_length,
......@@ -864,7 +864,7 @@ class WhisperForConditionalGeneration(
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
) -> int | None:
processor = cached_get_processor(model_config.model)
processor = cached_processor_from_config(model_config)
hop_length = processor.feature_extractor.hop_length
assert hop_length is not None
# NOTE(NickLucche) user can't pass encoder
......
......@@ -6,8 +6,7 @@ from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from .cache import BaseMultiModalProcessorCache
from .processing import (
......
......@@ -4,12 +4,21 @@
from .hf import HfTokenizer
from .mistral import MistralTokenizer
from .protocol import TokenizerLike
from .registry import TokenizerRegistry, get_tokenizer
from .registry import (
TokenizerRegistry,
cached_get_tokenizer,
cached_tokenizer_from_config,
get_tokenizer,
init_tokenizer_from_config,
)
__all__ = [
"TokenizerLike",
"HfTokenizer",
"MistralTokenizer",
"TokenizerRegistry",
"cached_get_tokenizer",
"get_tokenizer",
"cached_tokenizer_from_config",
"init_tokenizer_from_config",
]
......@@ -2,10 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib.util
from collections.abc import Callable
from functools import lru_cache
from pathlib import Path
from typing import TypeVar, overload
from typing import TYPE_CHECKING, TypeVar, overload
import huggingface_hub
from typing_extensions import assert_never
import vllm.envs as envs
from vllm.logger import init_logger
......@@ -21,6 +23,9 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
from .protocol import TokenizerLike
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = init_logger(__name__)
_T = TypeVar("_T", bound=type[TokenizerLike])
......@@ -195,3 +200,34 @@ def get_tokenizer(
)
return tokenizer
cached_get_tokenizer = lru_cache(get_tokenizer)
def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs):
return cached_get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
revision=model_config.tokenizer_revision,
trust_remote_code=model_config.trust_remote_code,
**kwargs,
)
def init_tokenizer_from_config(model_config: "ModelConfig"):
runner_type = model_config.runner_type
if runner_type == "generate" or runner_type == "draft":
truncation_side = "left"
elif runner_type == "pooling":
truncation_side = "right"
else:
assert_never(runner_type)
return get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.tokenizer_revision,
truncation_side=truncation_side,
)
......@@ -2,17 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings
from functools import lru_cache
from typing import TYPE_CHECKING, Any
from typing_extensions import assert_never
from typing import Any
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike, get_tokenizer
if TYPE_CHECKING:
from vllm.config import ModelConfig
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
......@@ -28,18 +21,54 @@ def __getattr__(name: str):
)
return TokenizerLike
if name == "get_cached_tokenizer":
from vllm.tokenizers.hf import get_cached_tokenizer
if name == "get_tokenizer":
from vllm.tokenizers import get_tokenizer
warnings.warn(
"`vllm.transformers_utils.tokenizer.get_tokenizer` "
"has been moved to `vllm.tokenizers.get_tokenizer`. "
"The old name will be removed in v0.13.",
DeprecationWarning,
stacklevel=2,
)
return get_tokenizer
if name == "cached_get_tokenizer":
from vllm.tokenizers import cached_get_tokenizer
warnings.warn(
"`vllm.transformers_utils.tokenizer.cached_get_tokenizer` "
"has been moved to `vllm.tokenizers.cached_get_tokenizer`. "
"The old name will be removed in v0.13.",
DeprecationWarning,
stacklevel=2,
)
return cached_get_tokenizer
if name == "cached_tokenizer_from_config":
from vllm.tokenizers import cached_tokenizer_from_config
warnings.warn(
"`vllm.transformers_utils.tokenizer.get_cached_tokenizer` "
"has been moved to `vllm.tokenizers.hf.get_cached_tokenizer`. "
"`vllm.transformers_utils.tokenizer.cached_tokenizer_from_config` "
"has been moved to `vllm.tokenizers.cached_tokenizer_from_config`. "
"The old name will be removed in v0.13.",
DeprecationWarning,
stacklevel=2,
)
return get_cached_tokenizer
return cached_tokenizer_from_config
if name == "init_tokenizer_from_configs":
from vllm.tokenizers import init_tokenizer_from_config
warnings.warn(
"`vllm.transformers_utils.tokenizer.init_tokenizer_from_configs` "
"has been moved to `vllm.tokenizers.init_tokenizer_from_config`. "
"The old name will be removed in v0.13.",
DeprecationWarning,
stacklevel=2,
)
return init_tokenizer_from_config
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
......@@ -92,37 +121,3 @@ def encode_tokens(
kw_args["add_special_tokens"] = add_special_tokens
return tokenizer.encode(text, **kw_args)
cached_get_tokenizer = lru_cache(get_tokenizer)
def cached_tokenizer_from_config(
model_config: "ModelConfig",
**kwargs: Any,
):
return cached_get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
revision=model_config.tokenizer_revision,
trust_remote_code=model_config.trust_remote_code,
**kwargs,
)
def init_tokenizer_from_configs(model_config: "ModelConfig"):
runner_type = model_config.runner_type
if runner_type == "generate" or runner_type == "draft":
truncation_side = "left"
elif runner_type == "pooling":
truncation_side = "right"
else:
assert_never(runner_type)
return get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.tokenizer_revision,
truncation_side=truncation_side,
)
......@@ -26,10 +26,9 @@ from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers import TokenizerLike, init_tokenizer_from_config
from vllm.tracing import init_tracer
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list
......@@ -112,7 +111,7 @@ class AsyncLLM(EngineClient):
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = init_tokenizer_from_configs(self.model_config)
tokenizer = init_tokenizer_from_config(self.model_config)
self.input_processor = InputProcessor(self.vllm_config, tokenizer)
self.io_processor = get_io_processor(
......
......@@ -23,9 +23,8 @@ from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers import TokenizerLike, init_tokenizer_from_config
from vllm.tracing import init_tracer
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
......@@ -87,7 +86,7 @@ class LLMEngine:
if self.model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = init_tokenizer_from_configs(self.model_config)
tokenizer = init_tokenizer_from_config(self.model_config)
self.input_processor = InputProcessor(self.vllm_config, tokenizer)
self.io_processor = get_io_processor(
......
......@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
from vllm.tokenizers import init_tokenizer_from_config
from vllm.utils.import_utils import LazyLoader
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
from vllm.v1.structured_output.backend_types import (
......@@ -61,7 +61,7 @@ class StructuredOutputManager:
# of CPUs.
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self.tokenizer = init_tokenizer_from_configs(
self.tokenizer = init_tokenizer_from_config(
model_config=self.vllm_config.model_config
)
reasoning_parser = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment