Commit 377d10bd (unverified), authored by Cyrus Leung, committed by GitHub
Browse files

[VLM][Bugfix] Pass processor kwargs properly on init (#13516)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 52ce14d3
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from functools import lru_cache
from itertools import groupby from itertools import groupby
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Optional, TypeVar, Union from typing import TYPE_CHECKING, Optional, TypeVar, Union
...@@ -13,7 +12,7 @@ from PIL import Image ...@@ -13,7 +12,7 @@ from PIL import Image
import vllm.envs as envs import vllm.envs as envs
from vllm.connections import HTTPConnection, global_http_connection from vllm.connections import HTTPConnection, global_http_connection
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from .audio import AudioMediaIO from .audio import AudioMediaIO
from .base import MediaIO from .base import MediaIO
...@@ -23,8 +22,6 @@ from .video import VideoMediaIO ...@@ -23,8 +22,6 @@ from .video import VideoMediaIO
logger = init_logger(__name__) logger = init_logger(__name__)
cached_get_tokenizer = lru_cache(get_tokenizer)
_M = TypeVar("_M") _M = TypeVar("_M")
if TYPE_CHECKING: if TYPE_CHECKING:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import base64 import base64
from functools import lru_cache, partial from functools import partial
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional from typing import TYPE_CHECKING, Any, Dict, Optional
...@@ -12,8 +12,7 @@ from PIL import Image ...@@ -12,8 +12,7 @@ from PIL import Image
from vllm.inputs.registry import InputContext from vllm.inputs.registry import InputContext
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.processor import get_video_processor from vllm.transformers_utils.processor import cached_get_video_processor
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import PlaceholderModule, is_list_of from vllm.utils import PlaceholderModule, is_list_of
from .base import MediaIO, ModalityData from .base import MediaIO, ModalityData
...@@ -30,9 +29,6 @@ except ImportError: ...@@ -30,9 +29,6 @@ except ImportError:
logger = init_logger(__name__) logger = init_logger(__name__)
cached_get_video_processor = lru_cache(get_video_processor)
cached_get_tokenizer = lru_cache(get_tokenizer)
class VideoPlugin(ImagePlugin): class VideoPlugin(ImagePlugin):
"""Plugin for video data.""" """Plugin for video data."""
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from functools import lru_cache from functools import lru_cache
from typing import Any, cast from typing import TYPE_CHECKING, Any, Union, cast
from transformers.processing_utils import ProcessorMixin from transformers.processing_utils import ProcessorMixin
from typing_extensions import TypeVar
if TYPE_CHECKING:
from vllm.config import ModelConfig
# Type variable for the HuggingFace processor class being requested; it is
# bounded by ProcessorMixin and falls back to ProcessorMixin by default.
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
class HashableDict(dict):
    """
    A dictionary that can be hashed by lru_cache.

    The hash is derived from the dictionary's items, so two instances with
    equal contents hash equally regardless of insertion order. Nested
    ``dict``, ``list``/``tuple`` and ``set`` values are frozen recursively
    for the purpose of hashing only; the stored values themselves are left
    untouched. Values of any other unhashable type still raise
    ``TypeError``.

    As with any hashable mutable object, an instance should not be mutated
    after it has been used as (part of) a cache key.
    """

    @staticmethod
    def _freeze(value):
        """Return a hashable stand-in for ``value`` (used for hashing only)."""
        if isinstance(value, dict):
            return frozenset(
                (key, HashableDict._freeze(item))
                for key, item in value.items())
        if isinstance(value, (list, tuple)):
            return tuple(HashableDict._freeze(item) for item in value)
        if isinstance(value, set):
            # Set members are necessarily hashable already.
            return frozenset(value)
        return value

    # NOTE: pythonic dict is not hashable,
    # we override on it directly for simplicity
    def __hash__(self) -> int:  # type: ignore[override]
        # For values that are already hashable this yields the same result
        # as hash(frozenset(self.items())).
        return hash(HashableDict._freeze(self))
def _merge_mm_kwargs(model_config: "ModelConfig", **kwargs):
base_kwargs = model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}
merged_kwargs = {**base_kwargs, **kwargs}
# NOTE: Pythonic dict is not hashable and will raise unhashable type
# error when calling `cached_get_processor`, therefore we need to
# wrap it to a hashable dict.
for key, value in merged_kwargs.items():
if isinstance(value, dict):
merged_kwargs[key] = HashableDict(value)
return merged_kwargs
def get_processor( def get_processor(
processor_name: str, processor_name: str,
*args: Any, *args: Any,
trust_remote_code: bool = False, trust_remote_code: bool = False,
processor_cls: type[ProcessorMixin] = ProcessorMixin, processor_cls: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
**kwargs: Any, **kwargs: Any,
): ) -> _P:
"""Load a processor for the given model name via HuggingFace.""" """Load a processor for the given model name via HuggingFace."""
# don't put this import at the top level # don't put this import at the top level
# it will call torch.cuda.device_count() # it will call torch.cuda.device_count()
from transformers import AutoProcessor from transformers import AutoProcessor
processor_factory = (AutoProcessor processor_factory = (AutoProcessor if processor_cls == ProcessorMixin or
if processor_cls == ProcessorMixin else processor_cls) isinstance(processor_cls, tuple) else processor_cls)
try: try:
processor = processor_factory.from_pretrained( processor = processor_factory.from_pretrained(
...@@ -43,12 +77,30 @@ def get_processor( ...@@ -43,12 +77,30 @@ def get_processor(
else: else:
raise e raise e
return cast(ProcessorMixin, processor) if not isinstance(processor, processor_cls):
raise TypeError("Invalid type of HuggingFace processor. "
f"Expected type: {processor_cls}, but "
f"found type: {type(processor)}")
return processor
cached_get_processor = lru_cache(get_processor) cached_get_processor = lru_cache(get_processor)
def cached_processor_from_config(
    model_config: "ModelConfig",
    processor_cls: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
    **kwargs: Any,
) -> _P:
    """Look up the processor for ``model_config`` via `cached_get_processor`.

    The model name and ``trust_remote_code`` flag are read from
    ``model_config``; ``kwargs`` are merged with the config's multi-modal
    processor kwargs by `_merge_mm_kwargs` before being forwarded.
    """
    model_name = model_config.model
    allow_remote_code = model_config.trust_remote_code
    mm_kwargs = _merge_mm_kwargs(model_config, **kwargs)
    return cached_get_processor(
        model_name,
        trust_remote_code=allow_remote_code,
        processor_cls=processor_cls,  # type: ignore[arg-type]
        **mm_kwargs,
    )
def get_image_processor( def get_image_processor(
processor_name: str, processor_name: str,
*args: Any, *args: Any,
...@@ -85,6 +137,20 @@ def get_image_processor( ...@@ -85,6 +137,20 @@ def get_image_processor(
return cast(BaseImageProcessor, processor) return cast(BaseImageProcessor, processor)
# lru_cache-wrapped variant of get_image_processor; because results are keyed
# on the call arguments, every argument passed to it must be hashable.
cached_get_image_processor = lru_cache(get_image_processor)
def cached_image_processor_from_config(
    model_config: "ModelConfig",
    **kwargs: Any,
):
    """Look up the image processor for ``model_config`` through the cache.

    The model name and ``trust_remote_code`` flag are read from
    ``model_config``; ``kwargs`` are merged with the config's multi-modal
    processor kwargs by `_merge_mm_kwargs` before being forwarded to
    `cached_get_image_processor`.
    """
    model_name = model_config.model
    allow_remote_code = model_config.trust_remote_code
    mm_kwargs = _merge_mm_kwargs(model_config, **kwargs)
    return cached_get_image_processor(
        model_name,
        trust_remote_code=allow_remote_code,
        **mm_kwargs,
    )
def get_video_processor( def get_video_processor(
processor_name: str, processor_name: str,
*args: Any, *args: Any,
...@@ -104,3 +170,17 @@ def get_video_processor( ...@@ -104,3 +170,17 @@ def get_video_processor(
) )
return cast(BaseImageProcessor, processor.video_processor) return cast(BaseImageProcessor, processor.video_processor)
# lru_cache-wrapped variant of get_video_processor; because results are keyed
# on the call arguments, every argument passed to it must be hashable.
cached_get_video_processor = lru_cache(get_video_processor)
def cached_video_processor_from_config(
    model_config: "ModelConfig",
    **kwargs: Any,
):
    """Look up the video processor for ``model_config`` through the cache.

    The model name and ``trust_remote_code`` flag are read from
    ``model_config``; ``kwargs`` are merged with the config's multi-modal
    processor kwargs by `_merge_mm_kwargs` before being forwarded to
    `cached_get_video_processor`.
    """
    model_name = model_config.model
    allow_remote_code = model_config.trust_remote_code
    mm_kwargs = _merge_mm_kwargs(model_config, **kwargs)
    return cached_get_video_processor(
        model_name,
        trust_remote_code=allow_remote_code,
        **mm_kwargs,
    )
...@@ -3,9 +3,10 @@ ...@@ -3,9 +3,10 @@
import contextlib import contextlib
import os import os
import warnings import warnings
from functools import lru_cache
from pathlib import Path from pathlib import Path
from types import MethodType from types import MethodType
from typing import Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
import huggingface_hub import huggingface_hub
from transformers import (AutoTokenizer, PreTrainedTokenizer, from transformers import (AutoTokenizer, PreTrainedTokenizer,
...@@ -20,6 +21,9 @@ from vllm.transformers_utils.tokenizers import MistralTokenizer ...@@ -20,6 +21,9 @@ from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import make_async from vllm.utils import make_async
if TYPE_CHECKING:
from vllm.config import ModelConfig
logger = init_logger(__name__) logger = init_logger(__name__)
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
...@@ -232,6 +236,22 @@ def get_tokenizer( ...@@ -232,6 +236,22 @@ def get_tokenizer(
return tokenizer return tokenizer
# lru_cache-wrapped variant of get_tokenizer; because results are keyed on
# the call arguments, every argument passed to it must be hashable.
cached_get_tokenizer = lru_cache(get_tokenizer)
def cached_tokenizer_from_config(
    model_config: "ModelConfig",
    **kwargs: Any,
):
    """Look up the tokenizer for ``model_config`` via `cached_get_tokenizer`.

    The tokenizer name, mode, revision and ``trust_remote_code`` flag are
    read from ``model_config``; any extra ``kwargs`` are forwarded unchanged
    after them.
    """
    tokenizer_name = model_config.tokenizer
    config_kwargs = dict(
        tokenizer_mode=model_config.tokenizer_mode,
        tokenizer_revision=model_config.tokenizer_revision,
        trust_remote_code=model_config.trust_remote_code,
    )
    return cached_get_tokenizer(tokenizer_name, **config_kwargs, **kwargs)
def get_lora_tokenizer(lora_request: LoRARequest, *args, def get_lora_tokenizer(lora_request: LoRARequest, *args,
**kwargs) -> Optional[AnyTokenizer]: **kwargs) -> Optional[AnyTokenizer]:
if lora_request is None: if lora_request is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment