Unverified Commit 87c41c26 authored by rongfu.leng's avatar rongfu.leng Committed by GitHub
Browse files

[Bugfix] Fix processor initialization for model from modelscope instead of HF (#27461)


Signed-off-by: default avatarrongfu.leng <rongfu.leng@daocloud.io>
Co-authored-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent 65d2cf95
......@@ -16,6 +16,7 @@ from transformers.processing_utils import ProcessorMixin
from transformers.video_processing_utils import BaseVideoProcessor
from typing_extensions import TypeVar
from vllm.transformers_utils.utils import convert_model_repo_to_path
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
if TYPE_CHECKING:
......@@ -94,8 +95,8 @@ def get_processor(
"""Load a processor for the given model name via HuggingFace."""
if revision is None:
revision = "main"
try:
processor_name = convert_model_repo_to_path(processor_name)
if isinstance(processor_cls, tuple) or processor_cls == ProcessorMixin:
processor = AutoProcessor.from_pretrained(
processor_name,
......@@ -168,6 +169,7 @@ def get_feature_extractor(
"""Load an audio feature extractor for the given model name
via HuggingFace."""
try:
processor_name = convert_model_repo_to_path(processor_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(
processor_name,
*args,
......@@ -217,6 +219,7 @@ def get_image_processor(
):
"""Load an image processor for the given model name via HuggingFace."""
try:
processor_name = convert_model_repo_to_path(processor_name)
processor = AutoImageProcessor.from_pretrained(
processor_name,
*args,
......@@ -268,6 +271,7 @@ def get_video_processor(
):
"""Load a video processor for the given model name via HuggingFace."""
try:
processor_name = convert_model_repo_to_path(processor_name)
processor_cls = processor_cls_overrides or AutoVideoProcessor
processor = processor_cls.from_pretrained(
processor_name,
......
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
import struct
from functools import cache
from os import PathLike
......@@ -109,3 +110,13 @@ def parse_safetensors_file_metadata(path: str | PathLike) -> dict[str, Any]:
length_of_metadata = struct.unpack("<Q", f.read(8))[0]
metadata = json.loads(f.read(length_of_metadata).decode("utf-8"))
return metadata
def convert_model_repo_to_path(model_repo: str) -> str:
"""When VLLM_USE_MODELSCOPE is True convert a model
repository string to a Path str."""
if not envs.VLLM_USE_MODELSCOPE or Path(model_repo).exists():
return model_repo
from modelscope.utils.file_utils import get_model_cache_root
return os.path.join(get_model_cache_root(), model_repo)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment