"tests/python/vscode:/vscode.git/clone" did not exist on "d3483fe1f11a73b19253b975221a48717d288edf"
Unverified Commit 9c3b58dc authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Handle deprecated transformer classes (#12517)

* update

* update

* update
parent 74b5fed4
...@@ -33,6 +33,7 @@ from ..utils import ( ...@@ -33,6 +33,7 @@ from ..utils import (
ONNX_WEIGHTS_NAME, ONNX_WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
_maybe_remap_transformers_class,
deprecate, deprecate,
get_class_from_dynamic_module, get_class_from_dynamic_module,
is_accelerate_available, is_accelerate_available,
...@@ -356,6 +357,11 @@ def maybe_raise_or_warn( ...@@ -356,6 +357,11 @@ def maybe_raise_or_warn(
"""Simple helper method to raise or warn in case incorrect module has been passed""" """Simple helper method to raise or warn in case incorrect module has been passed"""
if not is_pipeline_module: if not is_pipeline_module:
library = importlib.import_module(library_name) library = importlib.import_module(library_name)
# Handle deprecated Transformers classes
if library_name == "transformers":
class_name = _maybe_remap_transformers_class(class_name) or class_name
class_obj = getattr(library, class_name) class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
...@@ -390,6 +396,11 @@ def simple_get_class_obj(library_name, class_name): ...@@ -390,6 +396,11 @@ def simple_get_class_obj(library_name, class_name):
class_obj = getattr(pipeline_module, class_name) class_obj = getattr(pipeline_module, class_name)
else: else:
library = importlib.import_module(library_name) library = importlib.import_module(library_name)
# Handle deprecated Transformers classes
if library_name == "transformers":
class_name = _maybe_remap_transformers_class(class_name) or class_name
class_obj = getattr(library, class_name) class_obj = getattr(library, class_name)
return class_obj return class_obj
...@@ -416,6 +427,10 @@ def get_class_obj_and_candidates( ...@@ -416,6 +427,10 @@ def get_class_obj_and_candidates(
# else we just import it from the library. # else we just import it from the library.
library = importlib.import_module(library_name) library = importlib.import_module(library_name)
# Handle deprecated Transformers classes
if library_name == "transformers":
class_name = _maybe_remap_transformers_class(class_name) or class_name
class_obj = getattr(library, class_name) class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()} class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
......
...@@ -38,7 +38,7 @@ from .constants import ( ...@@ -38,7 +38,7 @@ from .constants import (
WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
) )
from .deprecation_utils import deprecate from .deprecation_utils import _maybe_remap_transformers_class, deprecate
from .doc_utils import replace_example_docstring from .doc_utils import replace_example_docstring
from .dynamic_modules_utils import get_class_from_dynamic_module from .dynamic_modules_utils import get_class_from_dynamic_module
from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video
......
...@@ -4,6 +4,54 @@ from typing import Any, Dict, Optional, Union ...@@ -4,6 +4,54 @@ from typing import Any, Dict, Optional, Union
from packaging import version from packaging import version
from ..utils import logging
logger = logging.get_logger(__name__)
# Mapping for deprecated Transformers classes to their replacements
# This is used to handle models that reference deprecated class names in their configs
# Reference: https://github.com/huggingface/transformers/issues/40822
# Format: {
# "DeprecatedClassName": {
# "new_class": "NewClassName",
# "transformers_version": (">=", "5.0.0"), # (operation, version) tuple
# }
# }
_TRANSFORMERS_CLASS_REMAPPING = {
"CLIPFeatureExtractor": {
"new_class": "CLIPImageProcessor",
"transformers_version": (">", "4.57.0"),
},
}
def _maybe_remap_transformers_class(class_name: str) -> Optional[str]:
"""
Check if a Transformers class should be remapped to a newer version.
Args:
class_name: The name of the class to check
Returns:
The new class name if remapping should occur, None otherwise
"""
if class_name not in _TRANSFORMERS_CLASS_REMAPPING:
return None
from .import_utils import is_transformers_version
mapping = _TRANSFORMERS_CLASS_REMAPPING[class_name]
operation, required_version = mapping["transformers_version"]
# Only remap if the transformers version meets the requirement
if is_transformers_version(operation, required_version):
new_class = mapping["new_class"]
logger.warning(f"{class_name} appears to have been deprecated in transformers. Using {new_class} instead.")
return mapping["new_class"]
return None
def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2): def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
from .. import __version__ from .. import __version__
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment