"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5d476f57c58c3cf7f39e764236c93c267fe83ca1"
Unverified Commit ba352aea authored by YiYi Xu, committed by GitHub

[feat] IP Adapters (author @okotaku ) (#5713)



* add ip-adapter


---------
Co-authored-by: okotaku <to78314910@gmail.com>
Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent 6fac1369
...@@ -1221,6 +1221,15 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
)
image_embeds = added_cond_kwargs.get("image_embeds")
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
# 2. pre-process
sample = self.conv_in(sample)
...
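Note on the new `ip_image_proj` branch above: the CLIP image embedding is projected through `encoder_hid_proj` and concatenated onto the text-encoder hidden states along the sequence dimension, so the extra image tokens reach every cross-attention layer. A minimal sketch of how a caller exercises this path, assuming `unet` is an already-instantiated `UNet2DConditionModel` with an IP-Adapter loaded (tensor shapes are illustrative only):

```python
import torch

# Assumption: `unet` already has an IP-Adapter loaded, e.g. via
# pipeline.load_ip_adapter(...) or unet._load_ip_adapter_weights(...),
# so unet.config.encoder_hid_dim_type == "ip_image_proj".
sample = torch.randn(1, 4, 64, 64)               # noisy latents
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 768)  # text embeddings (illustrative shape)
image_embeds = torch.randn(1, 1, 1024)           # CLIP image embedding (illustrative shape)

noise_pred = unet(
    sample,
    timestep,
    encoder_hidden_states=encoder_hidden_states,
    # consumed by the elif branch above: projected, then concatenated along dim=1
    added_cond_kwargs={"image_embeds": image_embeds},
).sample
```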
...@@ -246,6 +246,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
lora_components = {
"unet_lora_layers": unet_lora_layers,
...@@ -757,6 +758,7 @@ class SDXInpaintLoraMixinTests(unittest.TestCase):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
...@@ -866,6 +868,8 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"image_encoder": None,
"feature_extractor": None,
}
lora_components = {
"unet_lora_layers": unet_lora_layers,
...
...@@ -140,6 +140,8 @@ class PeftLoraLoaderMixinTests:
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"image_encoder": None,
"feature_extractor": None,
}
else:
pipeline_components = {
...@@ -150,6 +152,7 @@ class PeftLoraLoaderMixinTests:
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
lora_components = {
"unet_lora_layers": unet_lora_layers,
...
...@@ -24,7 +24,8 @@ from parameterized import parameterized
from pytest import mark
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor
from diffusers.models.embeddings import ImageProjection
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
...@@ -45,6 +46,57 @@ logger = logging.get_logger(__name__)
enable_full_determinism()
def create_ip_adapter_state_dict(model):
# "ip_adapter" (cross-attention weights)
ip_cross_attn_state_dict = {}
key_id = 1
for name in model.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = model.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = model.config.block_out_channels[block_id]
if cross_attention_dim is not None:
sd = IPAdapterAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
).state_dict()
ip_cross_attn_state_dict.update(
{
f"{key_id}.to_k_ip.weight": sd["to_k_ip.weight"],
f"{key_id}.to_v_ip.weight": sd["to_v_ip.weight"],
}
)
key_id += 2
# "image_proj" (ImageProjection layer weights)
cross_attention_dim = model.config["cross_attention_dim"]
image_projection = ImageProjection(
cross_attention_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, num_image_text_embeds=4
)
ip_image_projection_state_dict = {}
sd = image_projection.state_dict()
ip_image_projection_state_dict.update(
{
"proj.weight": sd["image_embeds.weight"],
"proj.bias": sd["image_embeds.bias"],
"norm.weight": sd["norm.weight"],
"norm.bias": sd["norm.bias"],
}
)
del sd
ip_state_dict = {}
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
return ip_state_dict
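The helper above mimics the layout of a real IP-Adapter checkpoint: a top-level "image_proj" dict holding the `ImageProjection` weights and a top-level "ip_adapter" dict holding per-layer `to_k_ip`/`to_v_ip` weights keyed by processor index. A hedged usage sketch, assuming `unet` is the small test UNet built elsewhere in this file:

```python
ip_state_dict = create_ip_adapter_state_dict(unet)
print(sorted(ip_state_dict.keys()))  # ['image_proj', 'ip_adapter']

# Loading the dict swaps the UNet's cross-attention processors for IP-Adapter
# processors and attaches an ImageProjection layer as `encoder_hid_proj`:
unet._load_ip_adapter_weights(ip_state_dict)
```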
def create_custom_diffusion_layers(model, mock_weights: bool = True):
train_kv = True
train_q_out = True
...@@ -622,6 +674,56 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
# Check if input and output shapes are the same
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_ip_adapter(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
model.to(torch_device)
# forward pass without ip-adapter
with torch.no_grad():
sample1 = model(**inputs_dict).sample
# update inputs_dict for ip-adapter
batch_size = inputs_dict["encoder_hidden_states"].shape[0]
image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to(torch_device)
inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds}
# make ip_adapter_1 and ip_adapter_2
ip_adapter_1 = create_ip_adapter_state_dict(model)
image_proj_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["image_proj"].items()}
cross_attn_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["ip_adapter"].items()}
ip_adapter_2 = {}
ip_adapter_2.update({"image_proj": image_proj_state_dict_2, "ip_adapter": cross_attn_state_dict_2})
# forward pass ip_adapter_1
model._load_ip_adapter_weights(ip_adapter_1)
assert model.config.encoder_hid_dim_type == "ip_image_proj"
assert model.encoder_hid_proj is not None
assert model.down_blocks[0].attentions[0].transformer_blocks[0].attn2.processor.__class__.__name__ in (
"IPAdapterAttnProcessor",
"IPAdapterAttnProcessor2_0",
)
with torch.no_grad():
sample2 = model(**inputs_dict).sample
# forward pass with ip_adapter_2
model._load_ip_adapter_weights(ip_adapter_2)
with torch.no_grad():
sample3 = model(**inputs_dict).sample
# forward pass with ip_adapter_1 again
model._load_ip_adapter_weights(ip_adapter_1)
with torch.no_grad():
sample4 = model(**inputs_dict).sample
assert not sample1.allclose(sample2, atol=1e-4, rtol=1e-4)
assert not sample2.allclose(sample3, atol=1e-4, rtol=1e-4)
assert sample2.allclose(sample4, atol=1e-4, rtol=1e-4)
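The assertions above check that loading an adapter changes the output, loading a second adapter changes it again, and re-loading the first adapter reproduces the first adapted output. If useful, the installed processors can also be inspected directly; a small hedged snippet (assumes `model` already has an adapter loaded, as in the test):

```python
from diffusers.models.attention_processor import IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0

# Count which attention layers now run IP-Adapter processors.
ip_layers = [
    name
    for name, proc in model.attn_processors.items()
    if isinstance(proc, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
]
print(f"{len(ip_layers)} cross-attention layers use IP-Adapter processors")
```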
@slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
...
...@@ -117,6 +117,7 @@ class AltDiffusionPipelineFastTests(
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
...
...@@ -141,6 +141,7 @@ class AltDiffusionImg2ImgPipelineFastTests(unittest.TestCase):
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
image_encoder=None,
)
alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=True)
alt_pipe = alt_pipe.to(device)
...@@ -205,6 +206,7 @@ class AltDiffusionImg2ImgPipelineFastTests(unittest.TestCase):
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
image_encoder=None,
)
alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=False)
alt_pipe = alt_pipe.to(torch_device)
...
...@@ -99,6 +99,8 @@ class AnimateDiffPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"motion_adapter": motion_adapter,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"feature_extractor": None,
"image_encoder": None,
}
return components
...
...@@ -183,6 +183,7 @@ class ControlNetPipelineFastTests(
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
...@@ -341,6 +342,7 @@ class StableDiffusionMultiControlNetPipelineFastTests(
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
...@@ -518,6 +520,7 @@ class StableDiffusionMultiControlNetOneModelPipelineFastTests(
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
...
...@@ -146,6 +146,8 @@ class StableDiffusionXLControlNetPipelineFastTests(
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"feature_extractor": None,
"image_encoder": None,
}
return components
...@@ -471,6 +473,8 @@ class StableDiffusionXLMultiControlNetPipelineFastTests(
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"feature_extractor": None,
"image_encoder": None,
}
return components
...@@ -656,6 +660,8 @@ class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"feature_extractor": None,
"image_encoder": None,
}
return components
...
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
from transformers import (
CLIPImageProcessor,
CLIPVisionModelWithProjection,
)
from diffusers import (
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline,
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
require_torch_gpu,
slow,
torch_device,
)
enable_full_determinism()
class IPAdapterNightlyTestsMixin(unittest.TestCase):
dtype = torch.float16
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_image_encoder(self, repo_id, subfolder):
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
repo_id, subfolder=subfolder, torch_dtype=self.dtype
).to(torch_device)
return image_encoder
def get_image_processor(self, repo_id):
image_processor = CLIPImageProcessor.from_pretrained(repo_id)
return image_processor
def get_dummy_inputs(self, for_image_to_image=False, for_inpainting=False, for_sdxl=False):
image = load_image(
"https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png"
)
if for_sdxl:
image = image.resize((1024, 1024))
input_kwargs = {
"prompt": "best quality, high quality",
"negative_prompt": "monochrome, lowres, bad anatomy, worst quality, low quality",
"num_inference_steps": 5,
"generator": torch.Generator(device="cpu").manual_seed(33),
"ip_adapter_image": image,
"output_type": "np",
}
if for_image_to_image:
image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/vermeer.jpg")
ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/river.png")
if for_sdxl:
image = image.resize((1024, 1024))
ip_image = ip_image.resize((1024, 1024))
input_kwargs.update({"image": image, "ip_adapter_image": ip_image})
elif for_inpainting:
image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/inpaint_image.png")
mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/mask.png")
ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/girl.png")
if for_sdxl:
image = image.resize((1024, 1024))
mask = mask.resize((1024, 1024))
ip_image = ip_image.resize((1024, 1024))
input_kwargs.update({"image": image, "mask_image": mask, "ip_adapter_image": ip_image})
return input_kwargs
@slow
@require_torch_gpu
class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
def test_text_to_image(self):
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
pipeline = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype
)
pipeline.to(torch_device)
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
inputs = self.get_dummy_inputs()
images = pipeline(**inputs).images
image_slice = images[0, :3, :3, -1].flatten()
expected_slice = np.array([0.8047, 0.8774, 0.9248, 0.9155, 0.9814, 1.0, 0.9678, 1.0, 1.0])
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
def test_image_to_image(self):
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype
)
pipeline.to(torch_device)
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
inputs = self.get_dummy_inputs(for_image_to_image=True)
images = pipeline(**inputs).images
image_slice = images[0, :3, :3, -1].flatten()
expected_slice = np.array([0.2307, 0.2341, 0.2305, 0.24, 0.2268, 0.25, 0.2322, 0.2588, 0.2935])
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
def test_inpainting(self):
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype
)
pipeline.to(torch_device)
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
inputs = self.get_dummy_inputs(for_inpainting=True)
images = pipeline(**inputs).images
image_slice = images[0, :3, :3, -1].flatten()
expected_slice = np.array([0.2705, 0.2395, 0.2209, 0.2312, 0.2102, 0.2104, 0.2178, 0.2065, 0.1997])
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
@slow
@require_torch_gpu
class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
def test_text_to_image_sdxl(self):
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="sdxl_models/image_encoder")
feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
image_encoder=image_encoder,
feature_extractor=feature_extractor,
torch_dtype=self.dtype,
)
pipeline.to(torch_device)
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
inputs = self.get_dummy_inputs()
images = pipeline(**inputs).images
image_slice = images[0, :3, :3, -1].flatten()
expected_slice = np.array([0.0968, 0.0959, 0.0852, 0.0912, 0.0948, 0.093, 0.0893, 0.0932, 0.0923])
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
def test_image_to_image_sdxl(self):
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="sdxl_models/image_encoder")
feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
image_encoder=image_encoder,
feature_extractor=feature_extractor,
torch_dtype=self.dtype,
)
pipeline.to(torch_device)
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
inputs = self.get_dummy_inputs(for_image_to_image=True)
images = pipeline(**inputs).images
image_slice = images[0, :3, :3, -1].flatten()
expected_slice = np.array([0.0653, 0.0704, 0.0725, 0.0741, 0.0702, 0.0647, 0.0782, 0.0799, 0.0752])
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
def test_inpainting_sdxl(self):
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="sdxl_models/image_encoder")
feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
image_encoder=image_encoder,
feature_extractor=feature_extractor,
torch_dtype=self.dtype,
)
pipeline.to(torch_device)
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
inputs = self.get_dummy_inputs(for_inpainting=True)
images = pipeline(**inputs).images
image_slice = images[0, :3, :3, -1].flatten()
image_slice.tolist()
expected_slice = np.array([0.1418, 0.1493, 0.1428, 0.146, 0.1491, 0.1501, 0.1473, 0.1501, 0.1516])
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
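Taken together, the nightly tests above exercise the new user-facing API: pass an `image_encoder` to the pipeline, call `load_ip_adapter`, then supply `ip_adapter_image` at inference time. A condensed sketch of the text-to-image path, reusing the checkpoints, weight file, and call arguments from `IPAdapterSDIntegrationTests.test_text_to_image` (the input image URL and seed are the ones used by the test; treat the whole snippet as illustrative rather than a pinned reference):

```python
import torch
from transformers import CLIPVisionModelWithProjection
from diffusers import StableDiffusionPipeline
from diffusers.utils import load_image

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16
)
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    image_encoder=image_encoder,
    safety_checker=None,
    torch_dtype=torch.float16,
).to("cuda")
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

ip_image = load_image(
    "https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png"
)
image = pipe(
    prompt="best quality, high quality",
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    ip_adapter_image=ip_image,
    num_inference_steps=5,
    generator=torch.Generator(device="cpu").manual_seed(33),
).images[0]
```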
...@@ -163,6 +163,7 @@ class StableDiffusionPipelineFastTests(
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
...
...@@ -150,6 +150,7 @@ class StableDiffusionImg2ImgPipelineFastTests(
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
...
...@@ -153,6 +153,7 @@ class StableDiffusionInpaintPipelineFastTests(
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
...@@ -353,6 +354,7 @@ class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipeli
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
...
...@@ -123,6 +123,7 @@ class StableDiffusion2PipelineFastTests(
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
...
...@@ -108,6 +108,7 @@ class StableDiffusion2InpaintPipelineFastTests(
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
...
...@@ -127,6 +127,7 @@ class StableDiffusion2VPredictionPipelineFastTests(unittest.TestCase):
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=None,
image_encoder=None,
requires_safety_checker=False,
)
sd_pipe = sd_pipe.to(device)
...@@ -176,6 +177,7 @@ class StableDiffusion2VPredictionPipelineFastTests(unittest.TestCase):
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=None,
image_encoder=None,
requires_safety_checker=False,
)
sd_pipe = sd_pipe.to(device)
...@@ -236,6 +238,7 @@ class StableDiffusion2VPredictionPipelineFastTests(unittest.TestCase):
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=None,
image_encoder=None,
requires_safety_checker=False,
)
sd_pipe = sd_pipe.to(torch_device)
...
...@@ -131,6 +131,8 @@ class StableDiffusionXLPipelineFastTests(
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"image_encoder": None,
"feature_extractor": None,
}
return components
...
...@@ -18,7 +18,15 @@ import unittest
import numpy as np
import torch
from transformers import (
CLIPImageProcessor,
CLIPTextConfig,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionConfig,
CLIPVisionModelWithProjection,
)
from diffusers import (
AutoencoderKL,
...@@ -95,6 +103,31 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
latent_channels=4,
sample_size=128,
)
torch.manual_seed(0)
image_encoder_config = CLIPVisionConfig(
hidden_size=32,
image_size=224,
projection_dim=32,
intermediate_size=37,
num_attention_heads=4,
num_channels=3,
num_hidden_layers=5,
patch_size=14,
)
image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
feature_extractor = CLIPImageProcessor(
crop_size=224,
do_center_crop=True,
do_normalize=True,
do_resize=True,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
resample=3,
size=224,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
...@@ -125,6 +158,8 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"requires_aesthetics_score": True,
"image_encoder": image_encoder,
"feature_extractor": feature_extractor,
}
return components
...@@ -458,6 +493,8 @@ class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests(
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"requires_aesthetics_score": True,
"image_encoder": None,
"feature_extractor": None,
}
return components
...
...@@ -20,7 +20,15 @@ import unittest
import numpy as np
import torch
from PIL import Image
from transformers import (
CLIPImageProcessor,
CLIPTextConfig,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionConfig,
CLIPVisionModelWithProjection,
)
from diffusers import (
AutoencoderKL,
...@@ -120,6 +128,31 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
torch.manual_seed(0)
image_encoder_config = CLIPVisionConfig(
hidden_size=32,
image_size=224,
projection_dim=32,
intermediate_size=37,
num_attention_heads=4,
num_channels=3,
num_hidden_layers=5,
patch_size=14,
)
image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
feature_extractor = CLIPImageProcessor(
crop_size=224,
do_center_crop=True,
do_normalize=True,
do_resize=True,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
resample=3,
size=224,
)
components = {
"unet": unet,
"scheduler": scheduler,
...@@ -128,6 +161,8 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
"tokenizer": tokenizer if not skip_first_text_encoder else None,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"image_encoder": image_encoder,
"feature_extractor": feature_extractor,
"requires_aesthetics_score": True, "requires_aesthetics_score": True,
} }
return components return components
......
...@@ -1136,8 +1136,8 @@ class PipelineFastTests(unittest.TestCase): ...@@ -1136,8 +1136,8 @@ class PipelineFastTests(unittest.TestCase):
safety_checker=None, safety_checker=None,
feature_extractor=self.dummy_extractor, feature_extractor=self.dummy_extractor,
).to(torch_device) ).to(torch_device)
img2img = StableDiffusionImg2ImgPipeline(**inpaint.components).to(torch_device) img2img = StableDiffusionImg2ImgPipeline(**inpaint.components, image_encoder=None).to(torch_device)
text2img = StableDiffusionPipeline(**inpaint.components).to(torch_device) text2img = StableDiffusionPipeline(**inpaint.components, image_encoder=None).to(torch_device)
prompt = "A painting of a squirrel eating a burger" prompt = "A painting of a squirrel eating a burger"
...@@ -1276,6 +1276,29 @@ class PipelineFastTests(unittest.TestCase): ...@@ -1276,6 +1276,29 @@ class PipelineFastTests(unittest.TestCase):
assert out_image.shape == (1, 64, 64, 3) assert out_image.shape == (1, 64, 64, 3)
assert np.abs(out_image - out_image_2).max() < 1e-3 assert np.abs(out_image - out_image_2).max() < 1e-3
def test_optional_components_is_none(self):
unet = self.dummy_cond_unet()
scheduler = PNDMScheduler(skip_prk_steps=True)
vae = self.dummy_vae
bert = self.dummy_text_encoder
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
items = {
"feature_extractor": self.dummy_extractor,
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": bert,
"tokenizer": tokenizer,
"safety_checker": None,
# we don't add an image encoder
}
pipeline = StableDiffusionPipeline(**items)
assert sorted(pipeline.components.keys()) == sorted(["image_encoder"] + list(items.keys()))
assert pipeline.image_encoder is None
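The test above reflects that `image_encoder` is now registered as an optional pipeline component: pipelines built without one still expose it (as `None`) and still list it in `pipeline.components`. A hedged illustration of the same behavior through `from_pretrained`:

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", safety_checker=None
)
assert pipe.image_encoder is None          # optional component defaults to None
assert "image_encoder" in pipe.components  # but it is still part of the components dict
```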
def test_set_scheduler_consistency(self):
unet = self.dummy_cond_unet()
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
...