Unverified Commit f2756253 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

Fix a bug in `AutoPipeline.from_pipe` when switching pipeline with optional components (#6820)



* fix

* add tests

---------
Co-authored-by: default avataryiyixuxu <yixu310@gmail,com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 0071478d
......@@ -166,8 +166,7 @@ class IPAdapterMixin:
pretrained_model_name_or_path_or_dict,
subfolder=Path(subfolder, "image_encoder").as_posix(),
).to(self.device, dtype=self.dtype)
self.image_encoder = image_encoder
self.register_to_config(image_encoder=["transformers", "CLIPVisionModelWithProjection"])
self.register_modules(image_encoder=image_encoder)
else:
raise ValueError("`image_encoder` cannot be None when using IP Adapters.")
......
......@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from collections import OrderedDict
from huggingface_hub.utils import validate_hf_hub_args
......@@ -164,14 +163,6 @@ def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool
raise ValueError(f"AutoPipeline can't find a pipeline linked to {pipeline_class_name} for {model_name}")
def _get_signature_keys(obj):
parameters = inspect.signature(obj.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
expected_modules = set(required_parameters.keys()) - {"self"}
return expected_modules, optional_parameters
class AutoPipelineForText2Image(ConfigMixin):
r"""
......@@ -391,7 +382,7 @@ class AutoPipelineForText2Image(ConfigMixin):
)
# define expected module and optional kwargs given the pipeline signature
expected_modules, optional_kwargs = _get_signature_keys(text_2_image_cls)
expected_modules, optional_kwargs = text_2_image_cls._get_signature_keys(text_2_image_cls)
pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
......@@ -668,7 +659,7 @@ class AutoPipelineForImage2Image(ConfigMixin):
)
# define expected module and optional kwargs given the pipeline signature
expected_modules, optional_kwargs = _get_signature_keys(image_2_image_cls)
expected_modules, optional_kwargs = image_2_image_cls._get_signature_keys(image_2_image_cls)
pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
......@@ -943,7 +934,7 @@ class AutoPipelineForInpainting(ConfigMixin):
)
# define expected module and optional kwargs given the pipeline signature
expected_modules, optional_kwargs = _get_signature_keys(inpainting_cls)
expected_modules, optional_kwargs = inpainting_cls._get_signature_keys(inpainting_cls)
pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
......
......@@ -21,6 +21,7 @@ from collections import OrderedDict
from pathlib import Path
import torch
from transformers import CLIPVisionConfig, CLIPVisionModelWithProjection
from diffusers import (
AutoPipelineForImage2Image,
......@@ -48,6 +49,20 @@ PRETRAINED_MODEL_REPO_MAPPING = OrderedDict(
class AutoPipelineFastTest(unittest.TestCase):
@property
def dummy_image_encoder(self):
torch.manual_seed(0)
config = CLIPVisionConfig(
hidden_size=1,
projection_dim=1,
num_hidden_layers=1,
num_attention_heads=1,
image_size=1,
intermediate_size=1,
patch_size=1,
)
return CLIPVisionModelWithProjection(config)
def test_from_pipe_consistent(self):
pipe = AutoPipelineForText2Image.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe", requires_safety_checker=False
......@@ -204,6 +219,20 @@ class AutoPipelineFastTest(unittest.TestCase):
assert pipe_control_img2img.__class__.__name__ == "StableDiffusionControlNetImg2ImgPipeline"
assert "controlnet" in pipe_control_img2img.components
def test_from_pipe_optional_components(self):
image_encoder = self.dummy_image_encoder
pipe = AutoPipelineForText2Image.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe",
image_encoder=image_encoder,
)
pipe = AutoPipelineForImage2Image.from_pipe(pipe)
assert pipe.image_encoder is not None
pipe = AutoPipelineForText2Image.from_pipe(pipe, image_encoder=None)
assert pipe.image_encoder is None
@slow
class AutoPipelineIntegrationTest(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment