Unverified Commit 377d10bd authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[VLM][Bugfix] Pass processor kwargs properly on init (#13516)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 52ce14d3
# SPDX-License-Identifier: Apache-2.0
from functools import lru_cache
from itertools import groupby
from pathlib import Path
from typing import TYPE_CHECKING, Optional, TypeVar, Union
......@@ -13,7 +12,7 @@ from PIL import Image
import vllm.envs as envs
from vllm.connections import HTTPConnection, global_http_connection
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .audio import AudioMediaIO
from .base import MediaIO
......@@ -23,8 +22,6 @@ from .video import VideoMediaIO
logger = init_logger(__name__)
cached_get_tokenizer = lru_cache(get_tokenizer)
_M = TypeVar("_M")
if TYPE_CHECKING:
......
# SPDX-License-Identifier: Apache-2.0
import base64
from functools import lru_cache, partial
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional
......@@ -12,8 +12,7 @@ from PIL import Image
from vllm.inputs.registry import InputContext
from vllm.logger import init_logger
from vllm.transformers_utils.processor import get_video_processor
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.transformers_utils.processor import cached_get_video_processor
from vllm.utils import PlaceholderModule, is_list_of
from .base import MediaIO, ModalityData
......@@ -30,9 +29,6 @@ except ImportError:
logger = init_logger(__name__)
cached_get_video_processor = lru_cache(get_video_processor)
cached_get_tokenizer = lru_cache(get_tokenizer)
class VideoPlugin(ImagePlugin):
"""Plugin for video data."""
......
# SPDX-License-Identifier: Apache-2.0
from functools import lru_cache
from typing import Any, cast
from typing import TYPE_CHECKING, Any, Union, cast
from transformers.processing_utils import ProcessorMixin
from typing_extensions import TypeVar
if TYPE_CHECKING:
from vllm.config import ModelConfig
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
class HashableDict(dict):
"""
A dictionary that can be hashed by lru_cache.
"""
# NOTE: pythonic dict is not hashable,
# we override on it directly for simplicity
def __hash__(self) -> int: # type: ignore[override]
return hash(frozenset(self.items()))
def _merge_mm_kwargs(model_config: "ModelConfig", **kwargs):
base_kwargs = model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}
merged_kwargs = {**base_kwargs, **kwargs}
# NOTE: Pythonic dict is not hashable and will raise unhashable type
# error when calling `cached_get_processor`, therefore we need to
# wrap it to a hashable dict.
for key, value in merged_kwargs.items():
if isinstance(value, dict):
merged_kwargs[key] = HashableDict(value)
return merged_kwargs
def get_processor(
processor_name: str,
*args: Any,
trust_remote_code: bool = False,
processor_cls: type[ProcessorMixin] = ProcessorMixin,
processor_cls: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
**kwargs: Any,
):
) -> _P:
"""Load a processor for the given model name via HuggingFace."""
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoProcessor
processor_factory = (AutoProcessor
if processor_cls == ProcessorMixin else processor_cls)
processor_factory = (AutoProcessor if processor_cls == ProcessorMixin or
isinstance(processor_cls, tuple) else processor_cls)
try:
processor = processor_factory.from_pretrained(
......@@ -43,12 +77,30 @@ def get_processor(
else:
raise e
return cast(ProcessorMixin, processor)
if not isinstance(processor, processor_cls):
raise TypeError("Invalid type of HuggingFace processor. "
f"Expected type: {processor_cls}, but "
f"found type: {type(processor)}")
return processor
cached_get_processor = lru_cache(get_processor)
def cached_processor_from_config(
model_config: "ModelConfig",
processor_cls: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
**kwargs: Any,
) -> _P:
return cached_get_processor(
model_config.model,
trust_remote_code=model_config.trust_remote_code,
processor_cls=processor_cls, # type: ignore[arg-type]
**_merge_mm_kwargs(model_config, **kwargs),
)
def get_image_processor(
processor_name: str,
*args: Any,
......@@ -85,6 +137,20 @@ def get_image_processor(
return cast(BaseImageProcessor, processor)
cached_get_image_processor = lru_cache(get_image_processor)
def cached_image_processor_from_config(
model_config: "ModelConfig",
**kwargs: Any,
):
return cached_get_image_processor(
model_config.model,
trust_remote_code=model_config.trust_remote_code,
**_merge_mm_kwargs(model_config, **kwargs),
)
def get_video_processor(
processor_name: str,
*args: Any,
......@@ -104,3 +170,17 @@ def get_video_processor(
)
return cast(BaseImageProcessor, processor.video_processor)
cached_get_video_processor = lru_cache(get_video_processor)
def cached_video_processor_from_config(
model_config: "ModelConfig",
**kwargs: Any,
):
return cached_get_video_processor(
model_config.model,
trust_remote_code=model_config.trust_remote_code,
**_merge_mm_kwargs(model_config, **kwargs),
)
......@@ -3,9 +3,10 @@
import contextlib
import os
import warnings
from functools import lru_cache
from pathlib import Path
from types import MethodType
from typing import Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
import huggingface_hub
from transformers import (AutoTokenizer, PreTrainedTokenizer,
......@@ -20,6 +21,9 @@ from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import make_async
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = init_logger(__name__)
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
......@@ -232,6 +236,22 @@ def get_tokenizer(
return tokenizer
cached_get_tokenizer = lru_cache(get_tokenizer)
def cached_tokenizer_from_config(
model_config: "ModelConfig",
**kwargs: Any,
):
return cached_get_tokenizer(
model_config.tokenizer,
tokenizer_mode=model_config.tokenizer_mode,
tokenizer_revision=model_config.tokenizer_revision,
trust_remote_code=model_config.trust_remote_code,
**kwargs,
)
def get_lora_tokenizer(lora_request: LoRARequest, *args,
**kwargs) -> Optional[AnyTokenizer]:
if lora_request is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment