Unverified Commit f2756253 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

Fix a bug in `AutoPipeline.from_pipe` when switching pipeline with optional components (#6820)



* fix

* add tests

---------
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 0071478d
...@@ -166,8 +166,7 @@ class IPAdapterMixin: ...@@ -166,8 +166,7 @@ class IPAdapterMixin:
pretrained_model_name_or_path_or_dict, pretrained_model_name_or_path_or_dict,
subfolder=Path(subfolder, "image_encoder").as_posix(), subfolder=Path(subfolder, "image_encoder").as_posix(),
).to(self.device, dtype=self.dtype) ).to(self.device, dtype=self.dtype)
self.image_encoder = image_encoder self.register_modules(image_encoder=image_encoder)
self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"])
else: else:
raise ValueError("`image_encoder` cannot be None when using IP Adapters.") raise ValueError("`image_encoder` cannot be None when using IP Adapters.")
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect
from collections import OrderedDict from collections import OrderedDict
from huggingface_hub.utils import validate_hf_hub_args from huggingface_hub.utils import validate_hf_hub_args
...@@ -164,14 +163,6 @@ def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool ...@@ -164,14 +163,6 @@ def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool
raise ValueError(f"AutoPipeline can't find a pipeline linked to {pipeline_class_name} for {model_name}") raise ValueError(f"AutoPipeline can't find a pipeline linked to {pipeline_class_name} for {model_name}")
def _get_signature_keys(obj):
parameters = inspect.signature(obj.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
expected_modules = set(required_parameters.keys()) - {"self"}
return expected_modules, optional_parameters
class AutoPipelineForText2Image(ConfigMixin): class AutoPipelineForText2Image(ConfigMixin):
r""" r"""
...@@ -391,7 +382,7 @@ class AutoPipelineForText2Image(ConfigMixin): ...@@ -391,7 +382,7 @@ class AutoPipelineForText2Image(ConfigMixin):
) )
# define expected module and optional kwargs given the pipeline signature # define expected module and optional kwargs given the pipeline signature
expected_modules, optional_kwargs = _get_signature_keys(text_2_image_cls) expected_modules, optional_kwargs = text_2_image_cls._get_signature_keys(text_2_image_cls)
pretrained_model_name_or_path = original_config.pop("_name_or_path", None) pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
...@@ -668,7 +659,7 @@ class AutoPipelineForImage2Image(ConfigMixin): ...@@ -668,7 +659,7 @@ class AutoPipelineForImage2Image(ConfigMixin):
) )
# define expected module and optional kwargs given the pipeline signature # define expected module and optional kwargs given the pipeline signature
expected_modules, optional_kwargs = _get_signature_keys(image_2_image_cls) expected_modules, optional_kwargs = image_2_image_cls._get_signature_keys(image_2_image_cls)
pretrained_model_name_or_path = original_config.pop("_name_or_path", None) pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
...@@ -943,7 +934,7 @@ class AutoPipelineForInpainting(ConfigMixin): ...@@ -943,7 +934,7 @@ class AutoPipelineForInpainting(ConfigMixin):
) )
# define expected module and optional kwargs given the pipeline signature # define expected module and optional kwargs given the pipeline signature
expected_modules, optional_kwargs = _get_signature_keys(inpainting_cls) expected_modules, optional_kwargs = inpainting_cls._get_signature_keys(inpainting_cls)
pretrained_model_name_or_path = original_config.pop("_name_or_path", None) pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
......
...@@ -21,6 +21,7 @@ from collections import OrderedDict ...@@ -21,6 +21,7 @@ from collections import OrderedDict
from pathlib import Path from pathlib import Path
import torch import torch
from transformers import CLIPVisionConfig, CLIPVisionModelWithProjection
from diffusers import ( from diffusers import (
AutoPipelineForImage2Image, AutoPipelineForImage2Image,
...@@ -48,6 +49,20 @@ PRETRAINED_MODEL_REPO_MAPPING = OrderedDict( ...@@ -48,6 +49,20 @@ PRETRAINED_MODEL_REPO_MAPPING = OrderedDict(
class AutoPipelineFastTest(unittest.TestCase): class AutoPipelineFastTest(unittest.TestCase):
@property
def dummy_image_encoder(self):
torch.manual_seed(0)
config = CLIPVisionConfig(
hidden_size=1,
projection_dim=1,
num_hidden_layers=1,
num_attention_heads=1,
image_size=1,
intermediate_size=1,
patch_size=1,
)
return CLIPVisionModelWithProjection(config)
def test_from_pipe_consistent(self): def test_from_pipe_consistent(self):
pipe = AutoPipelineForText2Image.from_pretrained( pipe = AutoPipelineForText2Image.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe", requires_safety_checker=False "hf-internal-testing/tiny-stable-diffusion-pipe", requires_safety_checker=False
...@@ -204,6 +219,20 @@ class AutoPipelineFastTest(unittest.TestCase): ...@@ -204,6 +219,20 @@ class AutoPipelineFastTest(unittest.TestCase):
assert pipe_control_img2img.__class__.__name__ == "StableDiffusionControlNetImg2ImgPipeline" assert pipe_control_img2img.__class__.__name__ == "StableDiffusionControlNetImg2ImgPipeline"
assert "controlnet" in pipe_control_img2img.components assert "controlnet" in pipe_control_img2img.components
def test_from_pipe_optional_components(self):
image_encoder = self.dummy_image_encoder
pipe = AutoPipelineForText2Image.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe",
image_encoder=image_encoder,
)
pipe = AutoPipelineForImage2Image.from_pipe(pipe)
assert pipe.image_encoder is not None
pipe = AutoPipelineForText2Image.from_pipe(pipe, image_encoder=None)
assert pipe.image_encoder is None
@slow @slow
class AutoPipelineIntegrationTest(unittest.TestCase): class AutoPipelineIntegrationTest(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment