Unverified Commit 653591d5 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Chore] Move tokenizer initialization methods (#29793)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent e2fbfc95
...@@ -73,12 +73,9 @@ from vllm.multimodal.processing import ( ...@@ -73,12 +73,9 @@ from vllm.multimodal.processing import (
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.transformers_utils.configs.radio import RadioConfig from vllm.transformers_utils.configs.radio import RadioConfig
from vllm.transformers_utils.tokenizer import ( from vllm.transformers_utils.tokenizer import encode_tokens
cached_tokenizer_from_config,
encode_tokens,
)
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .utils import _merge_multimodal_embeddings from .utils import _merge_multimodal_embeddings
......
...@@ -59,8 +59,7 @@ from vllm.multimodal.processing import ( ...@@ -59,8 +59,7 @@ from vllm.multimodal.processing import (
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tokenizers import MistralTokenizer from vllm.tokenizers import MistralTokenizer, cached_tokenizer_from_config
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
......
...@@ -51,8 +51,7 @@ from vllm.multimodal.processing import ( ...@@ -51,8 +51,7 @@ from vllm.multimodal.processing import (
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.tokenizers import MistralTokenizer from vllm.tokenizers import MistralTokenizer, cached_tokenizer_from_config
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
from .utils import init_vllm_registered_model, maybe_prefix from .utils import init_vllm_registered_model, maybe_prefix
......
...@@ -48,7 +48,7 @@ from vllm.multimodal.processing import ( ...@@ -48,7 +48,7 @@ from vllm.multimodal.processing import (
PromptUpdate, PromptUpdate,
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.utils.jsontree import json_map_leaves from vllm.utils.jsontree import json_map_leaves
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
from vllm.utils.torch_utils import set_default_torch_dtype from vllm.utils.torch_utils import set_default_torch_dtype
...@@ -850,7 +850,7 @@ class WhisperForConditionalGeneration( ...@@ -850,7 +850,7 @@ class WhisperForConditionalGeneration(
def get_speech_to_text_config( def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: str cls, model_config: ModelConfig, task_type: str
) -> SpeechToTextConfig: ) -> SpeechToTextConfig:
processor = cached_get_processor(model_config.model) processor = cached_processor_from_config(model_config)
return SpeechToTextConfig( return SpeechToTextConfig(
max_audio_clip_s=processor.feature_extractor.chunk_length, max_audio_clip_s=processor.feature_extractor.chunk_length,
...@@ -864,7 +864,7 @@ class WhisperForConditionalGeneration( ...@@ -864,7 +864,7 @@ class WhisperForConditionalGeneration(
stt_config: SpeechToTextConfig, stt_config: SpeechToTextConfig,
model_config: ModelConfig, model_config: ModelConfig,
) -> int | None: ) -> int | None:
processor = cached_get_processor(model_config.model) processor = cached_processor_from_config(model_config)
hop_length = processor.feature_extractor.hop_length hop_length = processor.feature_extractor.hop_length
assert hop_length is not None assert hop_length is not None
# NOTE(NickLucche) user can't pass encoder # NOTE(NickLucche) user can't pass encoder
......
...@@ -6,8 +6,7 @@ from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast ...@@ -6,8 +6,7 @@ from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from .cache import BaseMultiModalProcessorCache from .cache import BaseMultiModalProcessorCache
from .processing import ( from .processing import (
......
...@@ -4,12 +4,21 @@ ...@@ -4,12 +4,21 @@
from .hf import HfTokenizer from .hf import HfTokenizer
from .mistral import MistralTokenizer from .mistral import MistralTokenizer
from .protocol import TokenizerLike from .protocol import TokenizerLike
from .registry import TokenizerRegistry, get_tokenizer from .registry import (
TokenizerRegistry,
cached_get_tokenizer,
cached_tokenizer_from_config,
get_tokenizer,
init_tokenizer_from_config,
)
__all__ = [ __all__ = [
"TokenizerLike", "TokenizerLike",
"HfTokenizer", "HfTokenizer",
"MistralTokenizer", "MistralTokenizer",
"TokenizerRegistry", "TokenizerRegistry",
"cached_get_tokenizer",
"get_tokenizer", "get_tokenizer",
"cached_tokenizer_from_config",
"init_tokenizer_from_config",
] ]
...@@ -2,10 +2,12 @@ ...@@ -2,10 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib.util import importlib.util
from collections.abc import Callable from collections.abc import Callable
from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import TypeVar, overload from typing import TYPE_CHECKING, TypeVar, overload
import huggingface_hub import huggingface_hub
from typing_extensions import assert_never
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -21,6 +23,9 @@ from vllm.utils.import_utils import resolve_obj_by_qualname ...@@ -21,6 +23,9 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
from .protocol import TokenizerLike from .protocol import TokenizerLike
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = init_logger(__name__) logger = init_logger(__name__)
_T = TypeVar("_T", bound=type[TokenizerLike]) _T = TypeVar("_T", bound=type[TokenizerLike])
...@@ -195,3 +200,34 @@ def get_tokenizer( ...@@ -195,3 +200,34 @@ def get_tokenizer(
) )
return tokenizer return tokenizer
cached_get_tokenizer = lru_cache(get_tokenizer)
def cached_tokenizer_from_config(model_config: "ModelConfig", **kwargs):
return cached_get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
revision=model_config.tokenizer_revision,
trust_remote_code=model_config.trust_remote_code,
**kwargs,
)
def init_tokenizer_from_config(model_config: "ModelConfig"):
runner_type = model_config.runner_type
if runner_type == "generate" or runner_type == "draft":
truncation_side = "left"
elif runner_type == "pooling":
truncation_side = "right"
else:
assert_never(runner_type)
return get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.tokenizer_revision,
truncation_side=truncation_side,
)
...@@ -2,17 +2,10 @@ ...@@ -2,17 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings import warnings
from functools import lru_cache from typing import Any
from typing import TYPE_CHECKING, Any
from typing_extensions import assert_never
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike, get_tokenizer from vllm.tokenizers import TokenizerLike
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -28,18 +21,54 @@ def __getattr__(name: str): ...@@ -28,18 +21,54 @@ def __getattr__(name: str):
) )
return TokenizerLike return TokenizerLike
if name == "get_cached_tokenizer": if name == "get_tokenizer":
from vllm.tokenizers.hf import get_cached_tokenizer from vllm.tokenizers import get_tokenizer
warnings.warn(
"`vllm.transformers_utils.tokenizer.get_tokenizer` "
"has been moved to `vllm.tokenizers.get_tokenizer`. "
"The old name will be removed in v0.13.",
DeprecationWarning,
stacklevel=2,
)
return get_tokenizer
if name == "cached_get_tokenizer":
from vllm.tokenizers import cached_get_tokenizer
warnings.warn(
"`vllm.transformers_utils.tokenizer.cached_get_tokenizer` "
"has been moved to `vllm.tokenizers.cached_get_tokenizer`. "
"The old name will be removed in v0.13.",
DeprecationWarning,
stacklevel=2,
)
return cached_get_tokenizer
if name == "cached_tokenizer_from_config":
from vllm.tokenizers import cached_tokenizer_from_config
warnings.warn(
"`vllm.transformers_utils.tokenizer.cached_tokenizer_from_config` "
"has been moved to `vllm.tokenizers.cached_tokenizer_from_config`. "
"The old name will be removed in v0.13.",
DeprecationWarning,
stacklevel=2,
)
return cached_tokenizer_from_config
if name == "init_tokenizer_from_configs":
from vllm.tokenizers import init_tokenizer_from_config
warnings.warn( warnings.warn(
"`vllm.transformers_utils.tokenizer.get_cached_tokenizer` " "`vllm.transformers_utils.tokenizer.init_tokenizer_from_configs` "
"has been moved to `vllm.tokenizers.hf.get_cached_tokenizer`. " "has been moved to `vllm.tokenizers.init_tokenizer_from_config`. "
"The old name will be removed in v0.13.", "The old name will be removed in v0.13.",
DeprecationWarning, DeprecationWarning,
stacklevel=2, stacklevel=2,
) )
return get_cached_tokenizer return init_tokenizer_from_config
raise AttributeError(f"module {__name__!r} has no attribute {name!r}") raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
...@@ -92,37 +121,3 @@ def encode_tokens( ...@@ -92,37 +121,3 @@ def encode_tokens(
kw_args["add_special_tokens"] = add_special_tokens kw_args["add_special_tokens"] = add_special_tokens
return tokenizer.encode(text, **kw_args) return tokenizer.encode(text, **kw_args)
cached_get_tokenizer = lru_cache(get_tokenizer)
def cached_tokenizer_from_config(
model_config: "ModelConfig",
**kwargs: Any,
):
return cached_get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
revision=model_config.tokenizer_revision,
trust_remote_code=model_config.trust_remote_code,
**kwargs,
)
def init_tokenizer_from_configs(model_config: "ModelConfig"):
runner_type = model_config.runner_type
if runner_type == "generate" or runner_type == "draft":
truncation_side = "left"
elif runner_type == "pooling":
truncation_side = "right"
else:
assert_never(runner_type)
return get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
trust_remote_code=model_config.trust_remote_code,
revision=model_config.tokenizer_revision,
truncation_side=truncation_side,
)
...@@ -26,10 +26,9 @@ from vllm.plugins.io_processors import get_io_processor ...@@ -26,10 +26,9 @@ from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike, init_tokenizer_from_config
from vllm.tracing import init_tracer from vllm.tracing import init_tracer
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils.async_utils import cancel_task_threadsafe from vllm.utils.async_utils import cancel_task_threadsafe
from vllm.utils.collection_utils import as_list from vllm.utils.collection_utils import as_list
...@@ -112,7 +111,7 @@ class AsyncLLM(EngineClient): ...@@ -112,7 +111,7 @@ class AsyncLLM(EngineClient):
if self.model_config.skip_tokenizer_init: if self.model_config.skip_tokenizer_init:
tokenizer = None tokenizer = None
else: else:
tokenizer = init_tokenizer_from_configs(self.model_config) tokenizer = init_tokenizer_from_config(self.model_config)
self.input_processor = InputProcessor(self.vllm_config, tokenizer) self.input_processor = InputProcessor(self.vllm_config, tokenizer)
self.io_processor = get_io_processor( self.io_processor = get_io_processor(
......
...@@ -23,9 +23,8 @@ from vllm.plugins.io_processors import get_io_processor ...@@ -23,9 +23,8 @@ from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.tasks import SupportedTask from vllm.tasks import SupportedTask
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike, init_tokenizer_from_config
from vllm.tracing import init_tracer from vllm.tracing import init_tracer
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.core_client import EngineCoreClient
...@@ -87,7 +86,7 @@ class LLMEngine: ...@@ -87,7 +86,7 @@ class LLMEngine:
if self.model_config.skip_tokenizer_init: if self.model_config.skip_tokenizer_init:
tokenizer = None tokenizer = None
else: else:
tokenizer = init_tokenizer_from_configs(self.model_config) tokenizer = init_tokenizer_from_config(self.model_config)
self.input_processor = InputProcessor(self.vllm_config, tokenizer) self.input_processor = InputProcessor(self.vllm_config, tokenizer)
self.io_processor = get_io_processor( self.io_processor = get_io_processor(
......
...@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING ...@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager from vllm.reasoning import ReasoningParserManager
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs from vllm.tokenizers import init_tokenizer_from_config
from vllm.utils.import_utils import LazyLoader from vllm.utils.import_utils import LazyLoader
from vllm.v1.structured_output.backend_guidance import GuidanceBackend from vllm.v1.structured_output.backend_guidance import GuidanceBackend
from vllm.v1.structured_output.backend_types import ( from vllm.v1.structured_output.backend_types import (
...@@ -61,7 +61,7 @@ class StructuredOutputManager: ...@@ -61,7 +61,7 @@ class StructuredOutputManager:
# of CPUs. # of CPUs.
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
self.executor = ThreadPoolExecutor(max_workers=max_workers) self.executor = ThreadPoolExecutor(max_workers=max_workers)
self.tokenizer = init_tokenizer_from_configs( self.tokenizer = init_tokenizer_from_config(
model_config=self.vllm_config.model_config model_config=self.vllm_config.model_config
) )
reasoning_parser = ( reasoning_parser = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment