Unverified Commit f64fa949 authored by Ishan Modi's avatar Ishan Modi Committed by GitHub
Browse files

[Feature] AutoModel can load components using model_index.json (#11401)



* update

* update

* update

* update

* addressed PR comments

* update

* addressed PR comments

* added tests

* addressed PR comments

* updates

* update

* addressed PR comments

* update

* fix style

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent 049082e0
...@@ -12,13 +12,16 @@ ...@@ -12,13 +12,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import importlib
import os import os
from typing import Optional, Union from typing import Optional, Union
from huggingface_hub.utils import validate_hf_hub_args from huggingface_hub.utils import validate_hf_hub_args
from ..configuration_utils import ConfigMixin from ..configuration_utils import ConfigMixin
from ..utils import logging
logger = logging.get_logger(__name__)
class AutoModel(ConfigMixin): class AutoModel(ConfigMixin):
...@@ -152,15 +155,50 @@ class AutoModel(ConfigMixin): ...@@ -152,15 +155,50 @@ class AutoModel(ConfigMixin):
"token": token, "token": token,
"local_files_only": local_files_only, "local_files_only": local_files_only,
"revision": revision, "revision": revision,
"subfolder": subfolder,
} }
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) library = None
orig_class_name = config["_class_name"] orig_class_name = None
library = importlib.import_module("diffusers") # Always attempt to fetch model_index.json first
try:
cls.config_name = "model_index.json"
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
if subfolder is not None and subfolder in config:
library, orig_class_name = config[subfolder]
load_config_kwargs.update({"subfolder": subfolder})
except EnvironmentError as e:
logger.debug(e)
# Unable to load from model_index.json so fallback to loading from config
if library is None and orig_class_name is None:
cls.config_name = "config.json"
config = cls.load_config(pretrained_model_or_path, subfolder=subfolder, **load_config_kwargs)
if "_class_name" in config:
# If we find a class name in the config, we can try to load the model as a diffusers model
orig_class_name = config["_class_name"]
library = "diffusers"
load_config_kwargs.update({"subfolder": subfolder})
elif "model_type" in config:
orig_class_name = "AutoModel"
library = "transformers"
load_config_kwargs.update({"subfolder": "" if subfolder is None else subfolder})
else:
raise ValueError(f"Couldn't find model associated with the config file at {pretrained_model_or_path}.")
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
model_cls, _ = get_class_obj_and_candidates(
library_name=library,
class_name=orig_class_name,
importable_classes=ALL_IMPORTABLE_CLASSES,
pipelines=None,
is_pipeline_module=False,
)
model_cls = getattr(library, orig_class_name, None)
if model_cls is None: if model_cls is None:
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.") raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
......
...@@ -335,14 +335,14 @@ def get_class_obj_and_candidates( ...@@ -335,14 +335,14 @@ def get_class_obj_and_candidates(
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
): ):
"""Simple helper method to retrieve class object of module as well as potential parent class objects""" """Simple helper method to retrieve class object of module as well as potential parent class objects"""
component_folder = os.path.join(cache_dir, component_name) component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None
if is_pipeline_module: if is_pipeline_module:
pipeline_module = getattr(pipelines, library_name) pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name) class_obj = getattr(pipeline_module, class_name)
class_candidates = dict.fromkeys(importable_classes.keys(), class_obj) class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
elif os.path.isfile(os.path.join(component_folder, library_name + ".py")): elif component_folder and os.path.isfile(os.path.join(component_folder, library_name + ".py")):
# load custom component # load custom component
class_obj = get_class_from_dynamic_module( class_obj = get_class_from_dynamic_module(
component_folder, module_file=library_name + ".py", class_name=class_name component_folder, module_file=library_name + ".py", class_name=class_name
......
import unittest
from unittest.mock import patch
from transformers import CLIPTextModel, LongformerModel
from diffusers.models import AutoModel, UNet2DConditionModel
class TestAutoModel(unittest.TestCase):
@patch(
"diffusers.models.AutoModel.load_config",
side_effect=[EnvironmentError("File not found"), {"_class_name": "UNet2DConditionModel"}],
)
def test_load_from_config_diffusers_with_subfolder(self, mock_load_config):
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="unet")
assert isinstance(model, UNet2DConditionModel)
@patch(
"diffusers.models.AutoModel.load_config",
side_effect=[EnvironmentError("File not found"), {"model_type": "clip_text_model"}],
)
def test_load_from_config_transformers_with_subfolder(self, mock_load_config):
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder")
assert isinstance(model, CLIPTextModel)
def test_load_from_config_without_subfolder(self):
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-longformer")
assert isinstance(model, LongformerModel)
def test_load_from_model_index(self):
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder")
assert isinstance(model, CLIPTextModel)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment