Unverified Commit 2c1677ee authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

allow passing components to connected pipelines when use the combined pipeline (#4883)



* fix

* add test

---------
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
parent c73e609a
...@@ -1147,8 +1147,22 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1147,8 +1147,22 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
"variant": variant, "variant": variant,
"use_safetensors": use_safetensors, "use_safetensors": use_safetensors,
} }
def get_connected_passed_kwargs(prefix):
connected_passed_class_obj = {
k.replace(f"{prefix}_", ""): w for k, w in passed_class_obj.items() if k.split("_")[0] == prefix
}
connected_passed_pipe_kwargs = {
k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix
}
connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
return connected_passed_kwargs
connected_pipes = { connected_pipes = {
prefix: DiffusionPipeline.from_pretrained(repo_id, **load_kwargs.copy()) prefix: DiffusionPipeline.from_pretrained(
repo_id, **load_kwargs.copy(), **get_connected_passed_kwargs(prefix)
)
for prefix, repo_id in connected_pipes.items() for prefix, repo_id in connected_pipes.items()
if repo_id is not None if repo_id is not None
} }
......
...@@ -18,7 +18,13 @@ import unittest ...@@ -18,7 +18,13 @@ import unittest
import torch import torch
from huggingface_hub import ModelCard from huggingface_hub import ModelCard
from diffusers import DiffusionPipeline, KandinskyV22CombinedPipeline, KandinskyV22Pipeline, KandinskyV22PriorPipeline from diffusers import (
DDPMScheduler,
DiffusionPipeline,
KandinskyV22CombinedPipeline,
KandinskyV22Pipeline,
KandinskyV22PriorPipeline,
)
from diffusers.pipelines.pipeline_utils import CONNECTED_PIPES_KEYS from diffusers.pipelines.pipeline_utils import CONNECTED_PIPES_KEYS
...@@ -101,3 +107,22 @@ class CombinedPipelineFastTest(unittest.TestCase): ...@@ -101,3 +107,22 @@ class CombinedPipelineFastTest(unittest.TestCase):
assert dict(component.config) == dict(comp.config) assert dict(component.config) == dict(comp.config)
else: else:
assert component.__class__ == comp.__class__ assert component.__class__ == comp.__class__
def test_load_connected_checkpoint_with_passed_obj(self):
pipeline = KandinskyV22CombinedPipeline.from_pretrained(
"hf-internal-testing/tiny-random-kandinsky-v22-decoder"
)
prior_scheduler = DDPMScheduler.from_config(pipeline.prior_scheduler.config)
scheduler = DDPMScheduler.from_config(pipeline.scheduler.config)
# make sure we pass a different scheduler and prior_scheduler
assert pipeline.prior_scheduler.__class__ != prior_scheduler.__class__
assert pipeline.scheduler.__class__ != scheduler.__class__
pipeline_new = KandinskyV22CombinedPipeline.from_pretrained(
"hf-internal-testing/tiny-random-kandinsky-v22-decoder",
prior_scheduler=prior_scheduler,
scheduler=scheduler,
)
assert dict(pipeline_new.prior_scheduler.config) == dict(prior_scheduler.config)
assert dict(pipeline_new.scheduler.config) == dict(scheduler.config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment