Unverified Commit 2e8d18e6 authored by YiYi Xu, committed by GitHub

[IP-Adapter] Support multiple IP-Adapters (#6573)
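This change lets a single pipeline load several IP-Adapters and condition on one image input per adapter. A minimal usage sketch, mirroring the new test_multi integration test added below (the checkpoints, subfolder, and scales come from that test; style_image and face_image are placeholder PIL images):

import torch
from transformers import CLIPVisionModelWithProjection
from diffusers import StableDiffusionPipeline

# the image encoder is loaded explicitly here, as in the test
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16
)
pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, torch_dtype=torch.float16
).to("cuda")

# a list of weight names loads several adapters at once
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="models",
    weight_name=["ip-adapter_sd15.bin", "ip-adapter-plus_sd15.bin"],
)
# one scale per loaded adapter
pipeline.set_ip_adapter_scale([0.7, 0.3])

# ip_adapter_image takes one entry per adapter; an entry may itself be a list of images
images = pipeline(
    prompt="a photo of a person",
    ip_adapter_image=[style_image, [face_image, face_image]],
).images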




---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Alvaro Somoza <somoza.alvaro@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 03373de0
@@ -766,6 +766,35 @@ class StableDiffusionXLImg2ImgPipeline(
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have the same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
if self.do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
return image_embeds
def _get_add_time_ids(
self,
original_size,
@@ -1337,13 +1366,9 @@ class StableDiffusionXLImg2ImgPipeline(
add_time_ids = add_time_ids.to(device)
if ip_adapter_image is not None:
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
image_embeds, negative_image_embeds = self.encode_image(
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image, device, batch_size * num_images_per_prompt
)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
image_embeds = image_embeds.to(device)
# 9. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
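For reference, the contract of the new prepare_ip_adapter_image_embeds helper added above: it checks that one image input is supplied per loaded IP-Adapter, encodes each input with its matching projection layer, and returns a list of per-adapter embeddings. A sketch of a call site (pipe and the variable names are illustrative):

# one entry per loaded adapter (here: two adapters)
image_embeds = pipe.prepare_ip_adapter_image_embeds(
    [image_for_adapter_1, image_for_adapter_2], device, batch_size * num_images_per_prompt
)
# -> a list with one tensor per adapter; under classifier-free guidance each tensor is
#    torch.cat([negative_image_embeds, positive_image_embeds]) along the leading dimension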
@@ -487,6 +487,35 @@ class StableDiffusionXLInpaintPipeline(
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have the same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
if self.do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
return image_embeds
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(
self,
@@ -1685,13 +1714,9 @@ class StableDiffusionXLInpaintPipeline(
add_time_ids = add_time_ids.to(device)
if ip_adapter_image is not None:
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
image_embeds, negative_image_embeds = self.encode_image(
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image, device, batch_size * num_images_per_prompt
)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
image_embeds = image_embeds.to(device)
# 11. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -563,6 +563,35 @@ class StableDiffusionXLAdapterPipeline(
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(self, ip_adapter_image, device, num_images_per_prompt):
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have the same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack([single_negative_image_embeds] * num_images_per_prompt, dim=0)
if self.do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
return image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -1068,12 +1097,9 @@ class StableDiffusionXLAdapterPipeline(
# 3.2 Encode ip_adapter_image
if ip_adapter_image is not None:
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
image_embeds, negative_image_embeds = self.encode_image(
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image, device, batch_size * num_images_per_prompt
)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
@@ -25,7 +25,11 @@ from parameterized import parameterized
from pytest import mark
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor
from diffusers.models.attention_processor import (
CustomDiffusionAttnProcessor,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
from diffusers.models.embeddings import ImageProjection, IPAdapterPlusImageProjection
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
@@ -73,8 +77,8 @@ def create_ip_adapter_state_dict(model):
).state_dict()
ip_cross_attn_state_dict.update(
{
f"{key_id}.to_k_ip.weight": sd["to_k_ip.weight"],
f"{key_id}.to_v_ip.weight": sd["to_v_ip.weight"],
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
}
)
@@ -124,8 +128,8 @@ def create_ip_adapter_plus_state_dict(model):
).state_dict()
ip_cross_attn_state_dict.update(
{
f"{key_id}.to_k_ip.weight": sd["to_k_ip.weight"],
f"{key_id}.to_v_ip.weight": sd["to_v_ip.weight"],
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
}
)
@@ -773,8 +777,9 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
# update inputs_dict for ip-adapter
batch_size = inputs_dict["encoder_hidden_states"].shape[0]
# for ip-adapter, image_embeds has shape [batch_size, num_images, embed_dim]
image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to(torch_device)
inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds}
inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds]}
# make ip_adapter_1 and ip_adapter_2
ip_adapter_1 = create_ip_adapter_state_dict(model)
@@ -785,7 +790,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
ip_adapter_2.update({"image_proj": image_proj_state_dict_2, "ip_adapter": cross_attn_state_dict_2})
# forward pass ip_adapter_1
model._load_ip_adapter_weights(ip_adapter_1)
model._load_ip_adapter_weights([ip_adapter_1])
assert model.config.encoder_hid_dim_type == "ip_image_proj"
assert model.encoder_hid_proj is not None
assert model.down_blocks[0].attentions[0].transformer_blocks[0].attn2.processor.__class__.__name__ in (
@@ -796,18 +801,39 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
sample2 = model(**inputs_dict).sample
# forward pass with ip_adapter_2
model._load_ip_adapter_weights(ip_adapter_2)
model._load_ip_adapter_weights([ip_adapter_2])
with torch.no_grad():
sample3 = model(**inputs_dict).sample
# forward pass with ip_adapter_1 again
model._load_ip_adapter_weights(ip_adapter_1)
model._load_ip_adapter_weights([ip_adapter_1])
with torch.no_grad():
sample4 = model(**inputs_dict).sample
# forward pass with multiple ip-adapters and multiple images
model._load_ip_adapter_weights([ip_adapter_1, ip_adapter_2])
# set the scale for ip_adapter_2 to 0 so that the result matches loading only ip_adapter_1
for attn_processor in model.attn_processors.values():
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
attn_processor.scale = [1, 0]
image_embeds_multi = image_embeds.repeat(1, 2, 1)
inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds_multi, image_embeds_multi]}
with torch.no_grad():
sample5 = model(**inputs_dict).sample
# forward pass with a single ip-adapter & a single image, where image_embeds is a plain 2-d tensor rather than a list
image_embeds = image_embeds.squeeze(1)
inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds}
model._load_ip_adapter_weights(ip_adapter_1)
with torch.no_grad():
sample6 = model(**inputs_dict).sample
assert not sample1.allclose(sample2, atol=1e-4, rtol=1e-4)
assert not sample2.allclose(sample3, atol=1e-4, rtol=1e-4)
assert sample2.allclose(sample4, atol=1e-4, rtol=1e-4)
assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
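Note: with several adapters loaded, each IPAdapterAttnProcessor (or IPAdapterAttnProcessor2_0) holds a list of scales, one per adapter, which is why zeroing the second adapter's scale above makes sample5 match the single-adapter result sample2.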
def test_ip_adapter_plus(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -823,8 +849,9 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
# update inputs_dict for ip-adapter
batch_size = inputs_dict["encoder_hidden_states"].shape[0]
image_embeds = floats_tensor((batch_size, 1, model.cross_attention_dim)).to(torch_device)
inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds}
# for ip-adapter-plus, image_embeds has shape [batch_size, num_images, sequence_length, embed_dim]
image_embeds = floats_tensor((batch_size, 1, 1, model.cross_attention_dim)).to(torch_device)
inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds]}
# make ip_adapter_1 and ip_adapter_2
ip_adapter_1 = create_ip_adapter_plus_state_dict(model)
@@ -835,7 +862,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
ip_adapter_2.update({"image_proj": image_proj_state_dict_2, "ip_adapter": cross_attn_state_dict_2})
# forward pass ip_adapter_1
model._load_ip_adapter_weights(ip_adapter_1)
model._load_ip_adapter_weights([ip_adapter_1])
assert model.config.encoder_hid_dim_type == "ip_image_proj"
assert model.encoder_hid_proj is not None
assert model.down_blocks[0].attentions[0].transformer_blocks[0].attn2.processor.__class__.__name__ in (
@@ -846,18 +873,39 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
sample2 = model(**inputs_dict).sample
# forward pass with ip_adapter_2
model._load_ip_adapter_weights(ip_adapter_2)
model._load_ip_adapter_weights([ip_adapter_2])
with torch.no_grad():
sample3 = model(**inputs_dict).sample
# forward pass with ip_adapter_1 again
model._load_ip_adapter_weights(ip_adapter_1)
model._load_ip_adapter_weights([ip_adapter_1])
with torch.no_grad():
sample4 = model(**inputs_dict).sample
# forward pass with multiple ip-adapters and multiple images
model._load_ip_adapter_weights([ip_adapter_1, ip_adapter_2])
# set the scale for ip_adapter_2 to 0 so that the result matches loading only ip_adapter_1
for attn_processor in model.attn_processors.values():
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
attn_processor.scale = [1, 0]
image_embeds_multi = image_embeds.repeat(1, 2, 1, 1)
inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds_multi, image_embeds_multi]}
with torch.no_grad():
sample5 = model(**inputs_dict).sample
# forward pass with a single ip-adapter & a single image, where image_embeds is a plain 3-d tensor rather than a list
image_embeds = image_embeds.squeeze(1)
inputs_dict["added_cond_kwargs"] = {"image_embeds": image_embeds}
model._load_ip_adapter_weights(ip_adapter_1)
with torch.no_grad():
sample6 = model(**inputs_dict).sample
assert not sample1.allclose(sample2, atol=1e-4, rtol=1e-4)
assert not sample2.allclose(sample3, atol=1e-4, rtol=1e-4)
assert sample2.allclose(sample4, atol=1e-4, rtol=1e-4)
assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
@slow
@@ -258,6 +258,27 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
]
assert processors == [True] * len(processors)
def test_multi(self):
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
pipeline = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, safety_checker=None, torch_dtype=self.dtype
)
pipeline.to(torch_device)
pipeline.load_ip_adapter(
"h94/IP-Adapter", subfolder="models", weight_name=["ip-adapter_sd15.bin", "ip-adapter-plus_sd15.bin"]
)
pipeline.set_ip_adapter_scale([0.7, 0.3])
inputs = self.get_dummy_inputs()
ip_adapter_image = inputs["ip_adapter_image"]
inputs["ip_adapter_image"] = [ip_adapter_image, [ip_adapter_image] * 2]
images = pipeline(**inputs).images
image_slice = images[0, :3, :3, -1].flatten()
expected_slice = np.array(
[0.5234375, 0.53515625, 0.5629883, 0.57128906, 0.59521484, 0.62109375, 0.57910156, 0.6201172, 0.6508789]
)
assert np.allclose(image_slice, expected_slice, atol=1e-3)
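The nested ip_adapter_image above illustrates the input convention for multiple adapters: one entry per loaded adapter, where each entry can be a single image or a list of images for that adapter (here one image for ip-adapter_sd15 and two for ip-adapter-plus_sd15).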
@slow
@require_torch_gpu