Unverified Commit 513f1fbf authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Allow passing non-default modules to pipeline (#188)



* Allow passing non-default modules to pipeline.

Override modules are recognized and replaced in the pipeline. However,
no check is performed about mismatched classes yet. This is because the
override module is already instantiated and we have no library or class
name to compare against.

* up

* add test
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent d7b69208
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import importlib import importlib
import inspect
import os import os
from typing import Optional, Union from typing import Optional, Union
...@@ -148,6 +149,12 @@ class DiffusionPipeline(ConfigMixin): ...@@ -148,6 +149,12 @@ class DiffusionPipeline(ConfigMixin):
diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
# some modules can be passed directly to the init
# in this case they are already instantiated in `kwargs`
# extract them here
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_kwargs = {} init_kwargs = {}
...@@ -158,8 +165,36 @@ class DiffusionPipeline(ConfigMixin): ...@@ -158,8 +165,36 @@ class DiffusionPipeline(ConfigMixin):
# 3. Load each module in the pipeline # 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items(): for name, (library_name, class_name) in init_dict.items():
is_pipeline_module = hasattr(pipelines, library_name) is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
# if the model is in a pipeline module, then we load it from the pipeline # if the model is in a pipeline module, then we load it from the pipeline
if is_pipeline_module: if name in passed_class_obj:
# 1. check that passed_class_obj has correct parent class
if not is_pipeline_module:
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
expected_class_obj = None
for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate):
expected_class_obj = class_candidate
if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
raise ValueError(
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}"
)
else:
logger.warn(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
" has the correct type"
)
# set passed class object
loaded_sub_model = passed_class_obj[name]
elif is_pipeline_module:
pipeline_module = getattr(pipelines, library_name) pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name) class_obj = getattr(pipeline_module, class_name)
importable_classes = ALL_IMPORTABLE_CLASSES importable_classes = ALL_IMPORTABLE_CLASSES
...@@ -171,6 +206,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -171,6 +206,7 @@ class DiffusionPipeline(ConfigMixin):
importable_classes = LOADABLE_CLASSES[library_name] importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
if loaded_sub_model is None:
load_method_name = None load_method_name = None
for class_name, class_candidate in class_candidates.items(): for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate): if issubclass(class_obj, class_candidate):
...@@ -187,7 +223,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -187,7 +223,7 @@ class DiffusionPipeline(ConfigMixin):
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
# 5. Instantiate the pipeline # 4. Instantiate the pipeline
model = pipeline_class(**init_kwargs) model = pipeline_class(**init_kwargs)
return model return model
......
...@@ -718,6 +718,28 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -718,6 +718,28 @@ class PipelineTesterMixin(unittest.TestCase):
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
@slow
def test_from_pretrained_hub_pass_model(self):
model_path = "google/ddpm-cifar10-32"
# pass unet into DiffusionPipeline
unet = UNet2DModel.from_pretrained(model_path)
ddpm_from_hub_custom_model = DDPMPipeline.from_pretrained(model_path, unet=unet)
ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet)
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
ddpm_from_hub_custom_model.scheduler.num_timesteps = 10
ddpm_from_hub.scheduler.num_timesteps = 10
generator = torch.manual_seed(0)
image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy")["sample"]
generator = generator.manual_seed(0)
new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"]
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
@slow @slow
def test_output_format(self): def test_output_format(self):
model_path = "google/ddpm-cifar10-32" model_path = "google/ddpm-cifar10-32"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment