Unverified commit ba352aea authored by YiYi Xu, committed by GitHub

[feat] IP Adapters (author @okotaku ) (#5713)



* add ip-adapter
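
A minimal text-to-image usage sketch of the new feature, assembled from the slow integration tests added further down in this diff (the checkpoint names, image URL, and prompt are taken from those tests; treat it as a sketch rather than final documentation):

```python
import torch
from transformers import CLIPVisionModelWithProjection

from diffusers import StableDiffusionPipeline
from diffusers.utils import load_image

# The SD 1.5 repo does not ship an image encoder, so it is passed in explicitly,
# exactly as the integration tests below do.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16
)
pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    image_encoder=image_encoder,
    safety_checker=None,
    torch_dtype=torch.float16,
).to("cuda")

# Loads the "image_proj" and "ip_adapter" (cross-attention) weights into the UNet.
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/girl.png")
image = pipeline(
    prompt="best quality, high quality",
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    ip_adapter_image=ip_image,
    num_inference_steps=50,
    generator=torch.Generator(device="cpu").manual_seed(33),
).images[0]
```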


---------
Co-authored-by: okotaku <to78314910@gmail.com>
Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent 6fac1369
......@@ -1221,6 +1221,15 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(image_embeds)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
            encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)

        # 2. pre-process
        sample = self.conv_in(sample)
......
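
For orientation, the new `ip_image_proj` branch above appends the projected image embedding to the text conditioning along the sequence dimension. A small shape sketch, assuming SD 1.5's usual 77-token, 768-dim text conditioning and the default `ImageProjection` with `num_image_text_embeds=4` (the configuration used by the test helper later in this diff):

```python
import torch

# text conditioning from the text encoder: (batch, 77, 768)
encoder_hidden_states = torch.randn(2, 77, 768)

# encoder_hid_proj maps the pooled CLIP image embedding to 4 extra "image tokens"
image_embeds = torch.randn(2, 4, 768)

encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
print(encoder_hidden_states.shape)  # torch.Size([2, 81, 768]) -> 77 text tokens + 4 image tokens
```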
......@@ -246,6 +246,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
lora_components = {
"unet_lora_layers": unet_lora_layers,
......@@ -757,6 +758,7 @@ class SDXInpaintLoraMixinTests(unittest.TestCase):
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
......@@ -866,6 +868,8 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"image_encoder": None,
"feature_extractor": None,
}
lora_components = {
"unet_lora_layers": unet_lora_layers,
......
......@@ -140,6 +140,8 @@ class PeftLoraLoaderMixinTests:
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"image_encoder": None,
"feature_extractor": None,
}
else:
pipeline_components = {
......@@ -150,6 +152,7 @@ class PeftLoraLoaderMixinTests:
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
lora_components = {
"unet_lora_layers": unet_lora_layers,
......
......@@ -24,7 +24,8 @@ from parameterized import parameterized
from pytest import mark
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor
from diffusers.models.embeddings import ImageProjection
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
......@@ -45,6 +46,57 @@ logger = logging.get_logger(__name__)
enable_full_determinism()
def create_ip_adapter_state_dict(model):
    # "ip_adapter" (cross-attention weights)
    ip_cross_attn_state_dict = {}
    key_id = 1

    for name in model.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = model.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(model.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = model.config.block_out_channels[block_id]
        if cross_attention_dim is not None:
            sd = IPAdapterAttnProcessor(
                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
            ).state_dict()
            ip_cross_attn_state_dict.update(
                {
                    f"{key_id}.to_k_ip.weight": sd["to_k_ip.weight"],
                    f"{key_id}.to_v_ip.weight": sd["to_v_ip.weight"],
                }
            )
            key_id += 2

    # "image_proj" (ImageProjection layer weights)
    cross_attention_dim = model.config["cross_attention_dim"]
    image_projection = ImageProjection(
        cross_attention_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, num_image_text_embeds=4
    )

    ip_image_projection_state_dict = {}
    sd = image_projection.state_dict()
    ip_image_projection_state_dict.update(
        {
            "proj.weight": sd["image_embeds.weight"],
            "proj.bias": sd["image_embeds.bias"],
            "norm.weight": sd["norm.weight"],
            "norm.bias": sd["norm.bias"],
        }
    )
    del sd

    ip_state_dict = {}
    ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
    return ip_state_dict


def create_custom_diffusion_layers(model, mock_weights: bool = True):
    train_kv = True
    train_q_out = True
......@@ -622,6 +674,56 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
        # Check if input and output shapes are the same
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_ip_adapter(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        init_dict["attention_head_dim"] = (8, 16)

        model = self.model_class(**init_dict)
        model.to(torch_device)

        # forward pass without ip-adapter
        with torch.no_grad():
            sample1 = model(**inputs_dict).sample

        # update inputs_dict for ip-adapter
        batch_size = inputs_dict["encoder_hidden_states"].shape[0]
        image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to(torch_device)
        inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds}

        # make ip_adapter_1 and ip_adapter_2
        ip_adapter_1 = create_ip_adapter_state_dict(model)

        image_proj_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["image_proj"].items()}
        cross_attn_state_dict_2 = {k: w + 1.0 for k, w in ip_adapter_1["ip_adapter"].items()}
        ip_adapter_2 = {}
        ip_adapter_2.update({"image_proj": image_proj_state_dict_2, "ip_adapter": cross_attn_state_dict_2})

        # forward pass ip_adapter_1
        model._load_ip_adapter_weights(ip_adapter_1)
        assert model.config.encoder_hid_dim_type == "ip_image_proj"
        assert model.encoder_hid_proj is not None
        assert model.down_blocks[0].attentions[0].transformer_blocks[0].attn2.processor.__class__.__name__ in (
            "IPAdapterAttnProcessor",
            "IPAdapterAttnProcessor2_0",
        )
        with torch.no_grad():
            sample2 = model(**inputs_dict).sample

        # forward pass with ip_adapter_2
        model._load_ip_adapter_weights(ip_adapter_2)
        with torch.no_grad():
            sample3 = model(**inputs_dict).sample

        # forward pass with ip_adapter_1 again
        model._load_ip_adapter_weights(ip_adapter_1)
        with torch.no_grad():
            sample4 = model(**inputs_dict).sample

        assert not sample1.allclose(sample2, atol=1e-4, rtol=1e-4)
        assert not sample2.allclose(sample3, atol=1e-4, rtol=1e-4)
        assert sample2.allclose(sample4, atol=1e-4, rtol=1e-4)
@slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
......
......@@ -117,6 +117,7 @@ class AltDiffusionPipelineFastTests(
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
......
......@@ -141,6 +141,7 @@ class AltDiffusionImg2ImgPipelineFastTests(unittest.TestCase):
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
image_encoder=None,
)
alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=True)
alt_pipe = alt_pipe.to(device)
......@@ -205,6 +206,7 @@ class AltDiffusionImg2ImgPipelineFastTests(unittest.TestCase):
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=self.dummy_extractor,
image_encoder=None,
)
alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=False)
alt_pipe = alt_pipe.to(torch_device)
......
......@@ -99,6 +99,8 @@ class AnimateDiffPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"motion_adapter": motion_adapter,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"feature_extractor": None,
"image_encoder": None,
}
return components
......
......@@ -183,6 +183,7 @@ class ControlNetPipelineFastTests(
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
......@@ -341,6 +342,7 @@ class StableDiffusionMultiControlNetPipelineFastTests(
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
......@@ -518,6 +520,7 @@ class StableDiffusionMultiControlNetOneModelPipelineFastTests(
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
......
......@@ -146,6 +146,8 @@ class StableDiffusionXLControlNetPipelineFastTests(
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"feature_extractor": None,
"image_encoder": None,
}
return components
......@@ -471,6 +473,8 @@ class StableDiffusionXLMultiControlNetPipelineFastTests(
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"feature_extractor": None,
"image_encoder": None,
}
return components
......@@ -656,6 +660,8 @@ class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"feature_extractor": None,
"image_encoder": None,
}
return components
......
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
from transformers import (
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
)

from diffusers import (
    StableDiffusionImg2ImgPipeline,
    StableDiffusionInpaintPipeline,
    StableDiffusionPipeline,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLInpaintPipeline,
    StableDiffusionXLPipeline,
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    require_torch_gpu,
    slow,
    torch_device,
)


enable_full_determinism()

class IPAdapterNightlyTestsMixin(unittest.TestCase):
    dtype = torch.float16

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_image_encoder(self, repo_id, subfolder):
        image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            repo_id, subfolder=subfolder, torch_dtype=self.dtype
        ).to(torch_device)
        return image_encoder

    def get_image_processor(self, repo_id):
        image_processor = CLIPImageProcessor.from_pretrained(repo_id)
        return image_processor

    def get_dummy_inputs(self, for_image_to_image=False, for_inpainting=False, for_sdxl=False):
        image = load_image(
            "https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png"
        )
        if for_sdxl:
            image = image.resize((1024, 1024))

        input_kwargs = {
            "prompt": "best quality, high quality",
            "negative_prompt": "monochrome, lowres, bad anatomy, worst quality, low quality",
            "num_inference_steps": 5,
            "generator": torch.Generator(device="cpu").manual_seed(33),
            "ip_adapter_image": image,
            "output_type": "np",
        }
        if for_image_to_image:
            image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/vermeer.jpg")
            ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/river.png")

            if for_sdxl:
                image = image.resize((1024, 1024))
                ip_image = ip_image.resize((1024, 1024))

            input_kwargs.update({"image": image, "ip_adapter_image": ip_image})

        elif for_inpainting:
            image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/inpaint_image.png")
            mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/mask.png")
            ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/girl.png")

            if for_sdxl:
                image = image.resize((1024, 1024))
                mask = mask.resize((1024, 1024))
                ip_image = ip_image.resize((1024, 1024))

            input_kwargs.update({"image": image, "mask_image": mask, "ip_adapter_image": ip_image})

        return input_kwargs

@slow
@require_torch_gpu
class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
    def test_text_to_image(self):
        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
        pipeline = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype
        )
        pipeline.to(torch_device)
        pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

        inputs = self.get_dummy_inputs()
        images = pipeline(**inputs).images
        image_slice = images[0, :3, :3, -1].flatten()

        expected_slice = np.array([0.8047, 0.8774, 0.9248, 0.9155, 0.9814, 1.0, 0.9678, 1.0, 1.0])

        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)

    def test_image_to_image(self):
        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
        pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype
        )
        pipeline.to(torch_device)
        pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

        inputs = self.get_dummy_inputs(for_image_to_image=True)
        images = pipeline(**inputs).images
        image_slice = images[0, :3, :3, -1].flatten()

        expected_slice = np.array([0.2307, 0.2341, 0.2305, 0.24, 0.2268, 0.25, 0.2322, 0.2588, 0.2935])

        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)

    def test_inpainting(self):
        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
        pipeline = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype
        )
        pipeline.to(torch_device)
        pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

        inputs = self.get_dummy_inputs(for_inpainting=True)
        images = pipeline(**inputs).images
        image_slice = images[0, :3, :3, -1].flatten()

        expected_slice = np.array([0.2705, 0.2395, 0.2209, 0.2312, 0.2102, 0.2104, 0.2178, 0.2065, 0.1997])

        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)


@slow
@require_torch_gpu
class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
    def test_text_to_image_sdxl(self):
        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="sdxl_models/image_encoder")
        feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

        pipeline = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
            torch_dtype=self.dtype,
        )
        pipeline.to(torch_device)
        pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

        inputs = self.get_dummy_inputs()
        images = pipeline(**inputs).images
        image_slice = images[0, :3, :3, -1].flatten()

        expected_slice = np.array([0.0968, 0.0959, 0.0852, 0.0912, 0.0948, 0.093, 0.0893, 0.0932, 0.0923])

        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)

    def test_image_to_image_sdxl(self):
        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="sdxl_models/image_encoder")
        feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

        pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
            torch_dtype=self.dtype,
        )
        pipeline.to(torch_device)
        pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

        inputs = self.get_dummy_inputs(for_image_to_image=True)
        images = pipeline(**inputs).images
        image_slice = images[0, :3, :3, -1].flatten()

        expected_slice = np.array([0.0653, 0.0704, 0.0725, 0.0741, 0.0702, 0.0647, 0.0782, 0.0799, 0.0752])

        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)

    def test_inpainting_sdxl(self):
        image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="sdxl_models/image_encoder")
        feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

        pipeline = StableDiffusionXLInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            image_encoder=image_encoder,
            feature_extractor=feature_extractor,
            torch_dtype=self.dtype,
        )
        pipeline.to(torch_device)
        pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

        inputs = self.get_dummy_inputs(for_inpainting=True)
        images = pipeline(**inputs).images
        image_slice = images[0, :3, :3, -1].flatten()
        image_slice.tolist()

        expected_slice = np.array([0.1418, 0.1493, 0.1428, 0.146, 0.1491, 0.1501, 0.1473, 0.1501, 0.1516])

        assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
......@@ -163,6 +163,7 @@ class StableDiffusionPipelineFastTests(
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
......
......@@ -150,6 +150,7 @@ class StableDiffusionImg2ImgPipelineFastTests(
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
......
......@@ -153,6 +153,7 @@ class StableDiffusionInpaintPipelineFastTests(
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
......@@ -353,6 +354,7 @@ class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipeli
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
......
......@@ -123,6 +123,7 @@ class StableDiffusion2PipelineFastTests(
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
......
......@@ -108,6 +108,7 @@ class StableDiffusion2InpaintPipelineFastTests(
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
......
......@@ -127,6 +127,7 @@ class StableDiffusion2VPredictionPipelineFastTests(unittest.TestCase):
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=None,
image_encoder=None,
requires_safety_checker=False,
)
sd_pipe = sd_pipe.to(device)
......@@ -176,6 +177,7 @@ class StableDiffusion2VPredictionPipelineFastTests(unittest.TestCase):
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=None,
image_encoder=None,
requires_safety_checker=False,
)
sd_pipe = sd_pipe.to(device)
......@@ -236,6 +238,7 @@ class StableDiffusion2VPredictionPipelineFastTests(unittest.TestCase):
tokenizer=tokenizer,
safety_checker=None,
feature_extractor=None,
image_encoder=None,
requires_safety_checker=False,
)
sd_pipe = sd_pipe.to(torch_device)
......
......@@ -131,6 +131,8 @@ class StableDiffusionXLPipelineFastTests(
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"image_encoder": None,
"feature_extractor": None,
}
return components
......
......@@ -18,7 +18,15 @@ import unittest
import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from transformers import (
CLIPImageProcessor,
CLIPTextConfig,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionConfig,
CLIPVisionModelWithProjection,
)
from diffusers import (
AutoencoderKL,
......@@ -95,6 +103,31 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
latent_channels=4,
sample_size=128,
)
torch.manual_seed(0)
image_encoder_config = CLIPVisionConfig(
hidden_size=32,
image_size=224,
projection_dim=32,
intermediate_size=37,
num_attention_heads=4,
num_channels=3,
num_hidden_layers=5,
patch_size=14,
)
image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
feature_extractor = CLIPImageProcessor(
crop_size=224,
do_center_crop=True,
do_normalize=True,
do_resize=True,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
resample=3,
size=224,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
......@@ -125,6 +158,8 @@ class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, Pipel
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"requires_aesthetics_score": True,
"image_encoder": image_encoder,
"feature_extractor": feature_extractor,
}
return components
......@@ -458,6 +493,8 @@ class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests(
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"requires_aesthetics_score": True,
"image_encoder": None,
"feature_extractor": None,
}
return components
......
......@@ -20,7 +20,15 @@ import unittest
import numpy as np
import torch
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from transformers import (
CLIPImageProcessor,
CLIPTextConfig,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionConfig,
CLIPVisionModelWithProjection,
)
from diffusers import (
AutoencoderKL,
......@@ -120,6 +128,31 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
torch.manual_seed(0)
image_encoder_config = CLIPVisionConfig(
hidden_size=32,
image_size=224,
projection_dim=32,
intermediate_size=37,
num_attention_heads=4,
num_channels=3,
num_hidden_layers=5,
patch_size=14,
)
image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
feature_extractor = CLIPImageProcessor(
crop_size=224,
do_center_crop=True,
do_normalize=True,
do_resize=True,
image_mean=[0.48145466, 0.4578275, 0.40821073],
image_std=[0.26862954, 0.26130258, 0.27577711],
resample=3,
size=224,
)
components = {
"unet": unet,
"scheduler": scheduler,
......@@ -128,6 +161,8 @@ class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, Pipel
"tokenizer": tokenizer if not skip_first_text_encoder else None,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"image_encoder": image_encoder,
"feature_extractor": feature_extractor,
"requires_aesthetics_score": True,
}
return components
......
......@@ -1136,8 +1136,8 @@ class PipelineFastTests(unittest.TestCase):
safety_checker=None,
feature_extractor=self.dummy_extractor,
).to(torch_device)
img2img = StableDiffusionImg2ImgPipeline(**inpaint.components).to(torch_device)
text2img = StableDiffusionPipeline(**inpaint.components).to(torch_device)
img2img = StableDiffusionImg2ImgPipeline(**inpaint.components, image_encoder=None).to(torch_device)
text2img = StableDiffusionPipeline(**inpaint.components, image_encoder=None).to(torch_device)
prompt = "A painting of a squirrel eating a burger"
......@@ -1276,6 +1276,29 @@ class PipelineFastTests(unittest.TestCase):
        assert out_image.shape == (1, 64, 64, 3)
        assert np.abs(out_image - out_image_2).max() < 1e-3

    def test_optional_components_is_none(self):
        unet = self.dummy_cond_unet()
        scheduler = PNDMScheduler(skip_prk_steps=True)

        vae = self.dummy_vae
        bert = self.dummy_text_encoder
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        items = {
            "feature_extractor": self.dummy_extractor,
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": bert,
            "tokenizer": tokenizer,
            "safety_checker": None,
            # we don't add an image encoder
        }

        pipeline = StableDiffusionPipeline(**items)

        assert sorted(pipeline.components.keys()) == sorted(["image_encoder"] + list(items.keys()))
        assert pipeline.image_encoder is None

    def test_set_scheduler_consistency(self):
        unet = self.dummy_cond_unet()
        pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
......