".github/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c646fbc1247e444c18da3db5c7ffb1438ca05394"
Unverified Commit 2c1677ee authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

allow passing components to connected pipelines when use the combined pipeline (#4883)



* fix

* add test

---------
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
parent c73e609a
...@@ -1147,8 +1147,22 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1147,8 +1147,22 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
"variant": variant, "variant": variant,
"use_safetensors": use_safetensors, "use_safetensors": use_safetensors,
} }
def get_connected_passed_kwargs(prefix):
connected_passed_class_obj = {
k.replace(f"{prefix}_", ""): w for k, w in passed_class_obj.items() if k.split("_")[0] == prefix
}
connected_passed_pipe_kwargs = {
k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix
}
connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
return connected_passed_kwargs
connected_pipes = { connected_pipes = {
prefix: DiffusionPipeline.from_pretrained(repo_id, **load_kwargs.copy()) prefix: DiffusionPipeline.from_pretrained(
repo_id, **load_kwargs.copy(), **get_connected_passed_kwargs(prefix)
)
for prefix, repo_id in connected_pipes.items() for prefix, repo_id in connected_pipes.items()
if repo_id is not None if repo_id is not None
} }
......
...@@ -18,7 +18,13 @@ import unittest ...@@ -18,7 +18,13 @@ import unittest
import torch import torch
from huggingface_hub import ModelCard from huggingface_hub import ModelCard
from diffusers import DiffusionPipeline, KandinskyV22CombinedPipeline, KandinskyV22Pipeline, KandinskyV22PriorPipeline from diffusers import (
DDPMScheduler,
DiffusionPipeline,
KandinskyV22CombinedPipeline,
KandinskyV22Pipeline,
KandinskyV22PriorPipeline,
)
from diffusers.pipelines.pipeline_utils import CONNECTED_PIPES_KEYS from diffusers.pipelines.pipeline_utils import CONNECTED_PIPES_KEYS
...@@ -101,3 +107,22 @@ class CombinedPipelineFastTest(unittest.TestCase): ...@@ -101,3 +107,22 @@ class CombinedPipelineFastTest(unittest.TestCase):
assert dict(component.config) == dict(comp.config) assert dict(component.config) == dict(comp.config)
else: else:
assert component.__class__ == comp.__class__ assert component.__class__ == comp.__class__
def test_load_connected_checkpoint_with_passed_obj(self):
pipeline = KandinskyV22CombinedPipeline.from_pretrained(
"hf-internal-testing/tiny-random-kandinsky-v22-decoder"
)
prior_scheduler = DDPMScheduler.from_config(pipeline.prior_scheduler.config)
scheduler = DDPMScheduler.from_config(pipeline.scheduler.config)
# make sure we pass a different scheduler and prior_scheduler
assert pipeline.prior_scheduler.__class__ != prior_scheduler.__class__
assert pipeline.scheduler.__class__ != scheduler.__class__
pipeline_new = KandinskyV22CombinedPipeline.from_pretrained(
"hf-internal-testing/tiny-random-kandinsky-v22-decoder",
prior_scheduler=prior_scheduler,
scheduler=scheduler,
)
assert dict(pipeline_new.prior_scheduler.config) == dict(prior_scheduler.config)
assert dict(pipeline_new.scheduler.config) == dict(scheduler.config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment