Unverified Commit d4197bf4 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Allow custom pipeline loading (#3504)

parent b134f6a8
...@@ -491,15 +491,19 @@ class DiffusionPipeline(ConfigMixin): ...@@ -491,15 +491,19 @@ class DiffusionPipeline(ConfigMixin):
library = module.__module__.split(".")[0] library = module.__module__.split(".")[0]
# check if the module is a pipeline module # check if the module is a pipeline module
pipeline_dir = module.__module__.split(".")[-2] if len(module.__module__.split(".")) > 2 else None module_path_items = module.__module__.split(".")
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
path = module.__module__.split(".") path = module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# if library is not in LOADABLE_CLASSES, then it is a custom module. # if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline # Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name. # folder so we set the library to module name.
if library not in LOADABLE_CLASSES or is_pipeline_module: if is_pipeline_module:
library = pipeline_dir library = pipeline_dir
elif library not in LOADABLE_CLASSES:
library = module.__module__
# retrieve class_name # retrieve class_name
class_name = module.__class__.__name__ class_name = module.__class__.__name__
...@@ -1039,7 +1043,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -1039,7 +1043,7 @@ class DiffusionPipeline(ConfigMixin):
# 6.2 Define all importable classes # 6.2 Define all importable classes
is_pipeline_module = hasattr(pipelines, library_name) is_pipeline_module = hasattr(pipelines, library_name)
importable_classes = ALL_IMPORTABLE_CLASSES if is_pipeline_module else LOADABLE_CLASSES[library_name] importable_classes = ALL_IMPORTABLE_CLASSES
loaded_sub_model = None loaded_sub_model = None
# 6.3 Use passed sub model or load class_name from library_name # 6.3 Use passed sub model or load class_name from library_name
......
...@@ -35,6 +35,7 @@ from transformers import CLIPImageProcessor, CLIPModel, CLIPTextConfig, CLIPText ...@@ -35,6 +35,7 @@ from transformers import CLIPImageProcessor, CLIPModel, CLIPTextConfig, CLIPText
from diffusers import ( from diffusers import (
AutoencoderKL, AutoencoderKL,
ConfigMixin,
DDIMPipeline, DDIMPipeline,
DDIMScheduler, DDIMScheduler,
DDPMPipeline, DDPMPipeline,
...@@ -44,6 +45,7 @@ from diffusers import ( ...@@ -44,6 +45,7 @@ from diffusers import (
EulerAncestralDiscreteScheduler, EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler, EulerDiscreteScheduler,
LMSDiscreteScheduler, LMSDiscreteScheduler,
ModelMixin,
PNDMScheduler, PNDMScheduler,
StableDiffusionImg2ImgPipeline, StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipelineLegacy, StableDiffusionInpaintPipelineLegacy,
...@@ -77,6 +79,17 @@ from diffusers.utils.testing_utils import ( ...@@ -77,6 +79,17 @@ from diffusers.utils.testing_utils import (
enable_full_determinism() enable_full_determinism()
class CustomEncoder(ModelMixin, ConfigMixin):
def __init__(self):
super().__init__()
class CustomPipeline(DiffusionPipeline):
def __init__(self, encoder: CustomEncoder, scheduler: DDIMScheduler):
super().__init__()
self.register_modules(encoder=encoder, scheduler=scheduler)
class DownloadTests(unittest.TestCase): class DownloadTests(unittest.TestCase):
def test_one_request_upon_cached(self): def test_one_request_upon_cached(self):
# TODO: For some reason this test fails on MPS where no HEAD call is made. # TODO: For some reason this test fails on MPS where no HEAD call is made.
...@@ -695,6 +708,20 @@ class CustomPipelineTests(unittest.TestCase): ...@@ -695,6 +708,20 @@ class CustomPipelineTests(unittest.TestCase):
# compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102 # compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102
assert output_str == "This is a local test" assert output_str == "This is a local test"
def test_custom_model_and_pipeline(self):
pipe = CustomPipeline(
encoder=CustomEncoder(),
scheduler=DDIMScheduler(),
)
with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname)
pipe_new = CustomPipeline.from_pretrained(tmpdirname)
pipe_new.save_pretrained(tmpdirname)
assert dict(pipe_new.config) == dict(pipe.config)
@slow @slow
@require_torch_gpu @require_torch_gpu
def test_download_from_git(self): def test_download_from_git(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment