"docs/source/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "94e19193ac0afaf1997fb2211a071400b679d92b"
Commit decac197 authored by anton-l's avatar anton-l
Browse files

Merge branch 'main' of github.com:huggingface/diffusers

parents ae73d95e 2fa1d648
...@@ -45,6 +45,10 @@ LOADABLE_CLASSES = { ...@@ -45,6 +45,10 @@ LOADABLE_CLASSES = {
}, },
} }
ALL_IMPORTABLE_CLASSES = {}
for library in LOADABLE_CLASSES:
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
class DiffusionPipeline(ConfigMixin): class DiffusionPipeline(ConfigMixin):
...@@ -105,10 +109,8 @@ class DiffusionPipeline(ConfigMixin): ...@@ -105,10 +109,8 @@ class DiffusionPipeline(ConfigMixin):
Add docstrings Add docstrings
""" """
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None) use_auth_token = kwargs.pop("use_auth_token", None)
...@@ -117,10 +119,8 @@ class DiffusionPipeline(ConfigMixin): ...@@ -117,10 +119,8 @@ class DiffusionPipeline(ConfigMixin):
cached_folder = snapshot_download( cached_folder = snapshot_download(
pretrained_model_name_or_path, pretrained_model_name_or_path,
cache_dir=cache_dir, cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download, resume_download=resume_download,
proxies=proxies, proxies=proxies,
output_loading_info=output_loading_info,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
) )
...@@ -147,20 +147,14 @@ class DiffusionPipeline(ConfigMixin): ...@@ -147,20 +147,14 @@ class DiffusionPipeline(ConfigMixin):
init_kwargs = {} init_kwargs = {}
# get all importable classes to get the load method name for custom models/components
# here we enforce that custom models/components should always subclass from base classes in tansformers and diffusers
all_importable_classes = {}
for library in LOADABLE_CLASSES:
all_importable_classes.update(LOADABLE_CLASSES[library])
for name, (library_name, class_name) in init_dict.items(): for name, (library_name, class_name) in init_dict.items():
# if the model is not in diffusers or transformers, we need to load it from the hub # if the model is not in diffusers or transformers, we need to load it from the hub
# assumes that it's a subclass of ModelMixin # assumes that it's a subclass of ModelMixin
if library_name == module_candidate_name: if library_name == module_candidate_name:
class_obj = get_class_from_dynamic_module(cached_folder, module, class_name, cached_folder) class_obj = get_class_from_dynamic_module(cached_folder, module, class_name, cached_folder)
# since it's not from a library, we need to check class candidates for all importable classes # since it's not from a library, we need to check class candidates for all importable classes
importable_classes = all_importable_classes importable_classes = ALL_IMPORTABLE_CLASSES
class_candidates = {c: class_obj for c in all_importable_classes} class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()}
else: else:
library = importlib.import_module(library_name) library = importlib.import_module(library_name)
class_obj = getattr(library, class_name) class_obj = getattr(library, class_name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment