Unverified Commit 90d0a74b authored by xinli-centml's avatar xinli-centml Committed by GitHub
Browse files

[Bugfix] Add revision to `transformers.Auto*.from_pretrained` processors (#17948)


Signed-off-by: default avatarXin Li <xin@centml.ai>
parent d74e5f37
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from functools import lru_cache from functools import lru_cache
from typing import TYPE_CHECKING, Any, Union, cast from typing import TYPE_CHECKING, Any, Optional, Union, cast
from transformers.processing_utils import ProcessorMixin from transformers.processing_utils import ProcessorMixin
from typing_extensions import TypeVar from typing_extensions import TypeVar
...@@ -54,6 +54,7 @@ def _merge_mm_kwargs(model_config: "ModelConfig", **kwargs): ...@@ -54,6 +54,7 @@ def _merge_mm_kwargs(model_config: "ModelConfig", **kwargs):
def get_processor( def get_processor(
processor_name: str, processor_name: str,
*args: Any, *args: Any,
revision: Optional[str] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
processor_cls: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin, processor_cls: Union[type[_P], tuple[type[_P], ...]] = ProcessorMixin,
**kwargs: Any, **kwargs: Any,
...@@ -70,6 +71,7 @@ def get_processor( ...@@ -70,6 +71,7 @@ def get_processor(
processor = processor_factory.from_pretrained( processor = processor_factory.from_pretrained(
processor_name, processor_name,
*args, *args,
revision=revision,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
**kwargs, **kwargs,
) )
...@@ -106,6 +108,7 @@ def cached_processor_from_config( ...@@ -106,6 +108,7 @@ def cached_processor_from_config(
) -> _P: ) -> _P:
return cached_get_processor( return cached_get_processor(
model_config.model, model_config.model,
revision=model_config.revision,
trust_remote_code=model_config.trust_remote_code, trust_remote_code=model_config.trust_remote_code,
processor_cls=processor_cls, # type: ignore[arg-type] processor_cls=processor_cls, # type: ignore[arg-type]
**_merge_mm_kwargs(model_config, **kwargs), **_merge_mm_kwargs(model_config, **kwargs),
...@@ -115,6 +118,7 @@ def cached_processor_from_config( ...@@ -115,6 +118,7 @@ def cached_processor_from_config(
def get_feature_extractor( def get_feature_extractor(
processor_name: str, processor_name: str,
*args: Any, *args: Any,
revision: Optional[str] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
**kwargs: Any, **kwargs: Any,
): ):
...@@ -128,6 +132,7 @@ def get_feature_extractor( ...@@ -128,6 +132,7 @@ def get_feature_extractor(
feature_extractor = AutoFeatureExtractor.from_pretrained( feature_extractor = AutoFeatureExtractor.from_pretrained(
processor_name, processor_name,
*args, *args,
revision=revision,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
**kwargs) **kwargs)
except ValueError as e: except ValueError as e:
...@@ -156,6 +161,7 @@ def cached_feature_extractor_from_config( ...@@ -156,6 +161,7 @@ def cached_feature_extractor_from_config(
): ):
return cached_get_feature_extractor( return cached_get_feature_extractor(
model_config.model, model_config.model,
revision=model_config.revision,
trust_remote_code=model_config.trust_remote_code, trust_remote_code=model_config.trust_remote_code,
**_merge_mm_kwargs(model_config, **kwargs), **_merge_mm_kwargs(model_config, **kwargs),
) )
...@@ -164,6 +170,7 @@ def cached_feature_extractor_from_config( ...@@ -164,6 +170,7 @@ def cached_feature_extractor_from_config(
def get_image_processor( def get_image_processor(
processor_name: str, processor_name: str,
*args: Any, *args: Any,
revision: Optional[str] = None,
trust_remote_code: bool = False, trust_remote_code: bool = False,
**kwargs: Any, **kwargs: Any,
): ):
...@@ -177,6 +184,7 @@ def get_image_processor( ...@@ -177,6 +184,7 @@ def get_image_processor(
processor = AutoImageProcessor.from_pretrained( processor = AutoImageProcessor.from_pretrained(
processor_name, processor_name,
*args, *args,
revision=revision,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
**kwargs) **kwargs)
except ValueError as e: except ValueError as e:
...@@ -206,6 +214,7 @@ def cached_image_processor_from_config( ...@@ -206,6 +214,7 @@ def cached_image_processor_from_config(
): ):
return cached_get_image_processor( return cached_get_image_processor(
model_config.model, model_config.model,
revision=model_config.revision,
trust_remote_code=model_config.trust_remote_code, trust_remote_code=model_config.trust_remote_code,
**_merge_mm_kwargs(model_config, **kwargs), **_merge_mm_kwargs(model_config, **kwargs),
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment