Unverified Commit 21c747fa authored by Jenyuan-Huang's avatar Jenyuan-Huang Committed by GitHub
Browse files

Support InstantStyle (#7668)



* enable control ip-adapter per-transformer block on-the-fly

---------
Co-authored-by: default avatarsayakpaul <spsayakpaul@gmail.com>
Co-authored-by: default avatarResearcherXman <xhs.research@gmail.com>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent 09129842
...@@ -640,3 +640,87 @@ image ...@@ -640,3 +640,87 @@ image
<div class="flex justify-center"> <div class="flex justify-center">
    <img src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ipa-controlnet-out.png" />     <img src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ipa-controlnet-out.png" />
</div> </div>
### Style & layout control
[InstantStyle](https://arxiv.org/abs/2404.02733) is a plug-and-play method on top of IP-Adapter, which disentangles style and layout from image prompt to control image generation. This is achieved by only inserting IP-Adapters to some specific part of the model.
By default IP-Adapters are inserted to all layers of the model. Use the [`~loaders.IPAdapterMixin.set_ip_adapter_scale`] method with a dictionary to assign scales to IP-Adapter at different layers.
```py
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image
import torch
pipeline = AutoPipelineForImage2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
scale = {
"down": {"block_2": [0.0, 1.0]},
"up": {"block_0": [0.0, 1.0, 0.0]},
}
pipeline.set_ip_adapter_scale(scale)
```
This will activate IP-Adapter at the second layer in the model's down-part block 2 and up-part block 0. The former is the layer where IP-Adapter injects layout information and the latter injects style. Inserting IP-Adapter to these two layers you can generate images following the style and layout of image prompt, but with contents more aligned to text prompt.
```py
style_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg")
generator = torch.Generator(device="cpu").manual_seed(42)
image = pipeline(
prompt="a cat, masterpiece, best quality, high quality",
image=style_image,
negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
guidance_scale=5,
num_inference_steps=30,
generator=generator,
).images[0]
image
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">IP-Adapter image</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit_style_layout_cat.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image</figcaption>
</div>
</div>
In contrast, inserting IP-Adapter to all layers will often generate images that overly focus on image prompt and diminish diversity.
Activate IP-Adapter only in the style layer and then call the pipeline again.
```py
scale = {
"up": {"block_0": [0.0, 1.0, 0.0]},
}
pipeline.set_ip_adapter_scale(scale)
generator = torch.Generator(device="cpu").manual_seed(42)
image = pipeline(
prompt="a cat, masterpiece, best quality, high quality",
image=style_image,
negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
guidance_scale=5,
num_inference_steps=30,
generator=generator,
).images[0]
image
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit_style_cat.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">IP-Adapter only in style layer</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/30518dfe089e6bf50008875077b44cb98fb2065c/diffusers/default_out.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">IP-Adapter in all layers</figcaption>
</div>
</div>
Note that you don't have to specify all layers in the dictionary. Those not included in the dictionary will be set to scale 0 which means disable IP-Adapter by default.
...@@ -28,6 +28,7 @@ from ..utils import ( ...@@ -28,6 +28,7 @@ from ..utils import (
is_transformers_available, is_transformers_available,
logging, logging,
) )
from .unet_loader_utils import _maybe_expand_lora_scales
if is_transformers_available(): if is_transformers_available():
...@@ -243,25 +244,55 @@ class IPAdapterMixin: ...@@ -243,25 +244,55 @@ class IPAdapterMixin:
def set_ip_adapter_scale(self, scale): def set_ip_adapter_scale(self, scale):
""" """
Sets the conditioning scale between text and image. Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
Example: Example:
```py ```py
pipeline.set_ip_adapter_scale(0.5) # To use original IP-Adapter
scale = 1.0
pipeline.set_ip_adapter_scale(scale)
# To use style block only
scale = {
"up": {"block_0": [0.0, 1.0, 0.0]},
}
pipeline.set_ip_adapter_scale(scale)
# To use style+layout blocks
scale = {
"down": {"block_2": [0.0, 1.0]},
"up": {"block_0": [0.0, 1.0, 0.0]},
}
pipeline.set_ip_adapter_scale(scale)
# To use style and layout from 2 reference images
scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
pipeline.set_ip_adapter_scale(scales)
``` ```
""" """
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
for attn_processor in unet.attn_processors.values(): if not isinstance(scale, list):
scale = [scale]
scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
for attn_name, attn_processor in unet.attn_processors.items():
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)): if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
if not isinstance(scale, list): if len(scale_configs) != len(attn_processor.scale):
scale = [scale] * len(attn_processor.scale)
if len(attn_processor.scale) != len(scale):
raise ValueError( raise ValueError(
f"`scale` should be a list of same length as the number if ip-adapters " f"Cannot assign {len(scale_configs)} scale_configs to "
f"Expected {len(attn_processor.scale)} but got {len(scale)}." f"{len(attn_processor.scale)} IP-Adapter."
) )
attn_processor.scale = scale elif len(scale_configs) == 1:
scale_configs = scale_configs * len(attn_processor.scale)
for i, scale_config in enumerate(scale_configs):
if isinstance(scale_config, dict):
for k, s in scale_config.items():
if attn_name.startswith(k):
attn_processor.scale[i] = s
else:
attn_processor.scale[i] = scale_config
def unload_ip_adapter(self): def unload_ip_adapter(self):
""" """
......
...@@ -38,7 +38,9 @@ def _translate_into_actual_layer_name(name): ...@@ -38,7 +38,9 @@ def _translate_into_actual_layer_name(name):
return ".".join((updown, block, attn)) return ".".join((updown, block, attn))
def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]]): def _maybe_expand_lora_scales(
unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]], default_scale=1.0
):
blocks_with_transformer = { blocks_with_transformer = {
"down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")], "down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
"up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")], "up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
...@@ -47,7 +49,11 @@ def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: List[ ...@@ -47,7 +49,11 @@ def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: List[
expanded_weight_scales = [ expanded_weight_scales = [
_maybe_expand_lora_scales_for_one_adapter( _maybe_expand_lora_scales_for_one_adapter(
weight_for_adapter, blocks_with_transformer, transformer_per_block, unet.state_dict() weight_for_adapter,
blocks_with_transformer,
transformer_per_block,
unet.state_dict(),
default_scale=default_scale,
) )
for weight_for_adapter in weight_scales for weight_for_adapter in weight_scales
] ]
...@@ -60,6 +66,7 @@ def _maybe_expand_lora_scales_for_one_adapter( ...@@ -60,6 +66,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
blocks_with_transformer: Dict[str, int], blocks_with_transformer: Dict[str, int],
transformer_per_block: Dict[str, int], transformer_per_block: Dict[str, int],
state_dict: None, state_dict: None,
default_scale: float = 1.0,
): ):
""" """
Expands the inputs into a more granular dictionary. See the example below for more details. Expands the inputs into a more granular dictionary. See the example below for more details.
...@@ -108,21 +115,36 @@ def _maybe_expand_lora_scales_for_one_adapter( ...@@ -108,21 +115,36 @@ def _maybe_expand_lora_scales_for_one_adapter(
scales = copy.deepcopy(scales) scales = copy.deepcopy(scales)
if "mid" not in scales: if "mid" not in scales:
scales["mid"] = 1 scales["mid"] = default_scale
elif isinstance(scales["mid"], list):
if len(scales["mid"]) == 1:
scales["mid"] = scales["mid"][0]
else:
raise ValueError(f"Expected 1 scales for mid, got {len(scales['mid'])}.")
for updown in ["up", "down"]: for updown in ["up", "down"]:
if updown not in scales: if updown not in scales:
scales[updown] = 1 scales[updown] = default_scale
# eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}} # eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}}
if not isinstance(scales[updown], dict): if not isinstance(scales[updown], dict):
scales[updown] = {f"block_{i}": scales[updown] for i in blocks_with_transformer[updown]} scales[updown] = {f"block_{i}": copy.deepcopy(scales[updown]) for i in blocks_with_transformer[updown]}
# eg {"down": "block_1": 1}} to {"down": "block_1": [1, 1]}} # eg {"down": {"block_1": 1}} to {"down": {"block_1": [1, 1]}}
for i in blocks_with_transformer[updown]: for i in blocks_with_transformer[updown]:
block = f"block_{i}" block = f"block_{i}"
# set not assigned blocks to default scale
if block not in scales[updown]:
scales[updown][block] = default_scale
if not isinstance(scales[updown][block], list): if not isinstance(scales[updown][block], list):
scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])] scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])]
elif len(scales[updown][block]) == 1:
# a list specifying scale to each masked IP input
scales[updown][block] = scales[updown][block] * transformer_per_block[updown]
elif len(scales[updown][block]) != transformer_per_block[updown]:
raise ValueError(
f"Expected {transformer_per_block[updown]} scales for {updown}.{block}, got {len(scales[updown][block])}."
)
# eg {"down": "block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1} # eg {"down": "block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1}
for i in blocks_with_transformer[updown]: for i in blocks_with_transformer[updown]:
......
...@@ -2229,44 +2229,51 @@ class IPAdapterAttnProcessor(nn.Module): ...@@ -2229,44 +2229,51 @@ class IPAdapterAttnProcessor(nn.Module):
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
): ):
if mask is not None: skip = False
if not isinstance(scale, list): if isinstance(scale, list):
scale = [scale] if all(s == 0 for s in scale):
skip = True
elif scale == 0:
skip = True
if not skip:
if mask is not None:
if not isinstance(scale, list):
scale = [scale] * mask.shape[1]
current_num_images = mask.shape[1]
for i in range(current_num_images):
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
_current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
_current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
mask_downsample = IPAdapterMaskProcessor.downsample(
mask[:, i, :, :],
batch_size,
_current_ip_hidden_states.shape[1],
_current_ip_hidden_states.shape[2],
)
mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
current_num_images = mask.shape[1] hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
for i in range(current_num_images): else:
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) ip_value = to_v_ip(current_ip_hidden_states)
ip_key = attn.head_to_batch_dim(ip_key) ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value) ip_value = attn.head_to_batch_dim(ip_value)
ip_attention_probs = attn.get_attention_scores(query, ip_key, None) ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
_current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
_current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states) current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
mask_downsample = IPAdapterMaskProcessor.downsample(
mask[:, i, :, :],
batch_size,
_current_ip_hidden_states.shape[1],
_current_ip_hidden_states.shape[2],
)
mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
else:
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)
ip_key = attn.head_to_batch_dim(ip_key) hidden_states = hidden_states + scale * current_ip_hidden_states
ip_value = attn.head_to_batch_dim(ip_value)
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
hidden_states = hidden_states + scale * current_ip_hidden_states
# linear proj # linear proj
hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[0](hidden_states)
...@@ -2439,57 +2446,64 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module): ...@@ -2439,57 +2446,64 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
): ):
if mask is not None: skip = False
if not isinstance(scale, list): if isinstance(scale, list):
scale = [scale] if all(s == 0 for s in scale):
skip = True
elif scale == 0:
skip = True
if not skip:
if mask is not None:
if not isinstance(scale, list):
scale = [scale] * mask.shape[1]
current_num_images = mask.shape[1]
for i in range(current_num_images):
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
_current_ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
_current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
_current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
current_num_images = mask.shape[1] mask_downsample = IPAdapterMaskProcessor.downsample(
for i in range(current_num_images): mask[:, i, :, :],
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :]) batch_size,
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :]) _current_ip_hidden_states.shape[1],
_current_ip_hidden_states.shape[2],
)
mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
else:
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim) # the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1 # TODO: add support for attn.scale when we move to Torch 2.1
_current_ip_hidden_states = F.scaled_dot_product_attention( current_ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
) )
_current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape( current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim batch_size, -1, attn.heads * head_dim
) )
_current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype) current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
mask_downsample = IPAdapterMaskProcessor.downsample(
mask[:, i, :, :],
batch_size,
_current_ip_hidden_states.shape[1],
_current_ip_hidden_states.shape[2],
)
mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
else:
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
current_ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
hidden_states = hidden_states + scale * current_ip_hidden_states hidden_states = hidden_states + scale * current_ip_hidden_states
# linear proj # linear proj
hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[0](hidden_states)
......
...@@ -73,7 +73,9 @@ class IPAdapterNightlyTestsMixin(unittest.TestCase): ...@@ -73,7 +73,9 @@ class IPAdapterNightlyTestsMixin(unittest.TestCase):
image_processor = CLIPImageProcessor.from_pretrained(repo_id) image_processor = CLIPImageProcessor.from_pretrained(repo_id)
return image_processor return image_processor
def get_dummy_inputs(self, for_image_to_image=False, for_inpainting=False, for_sdxl=False, for_masks=False): def get_dummy_inputs(
self, for_image_to_image=False, for_inpainting=False, for_sdxl=False, for_masks=False, for_instant_style=False
):
image = load_image( image = load_image(
"https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png" "https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png"
) )
...@@ -126,6 +128,40 @@ class IPAdapterNightlyTestsMixin(unittest.TestCase): ...@@ -126,6 +128,40 @@ class IPAdapterNightlyTestsMixin(unittest.TestCase):
} }
) )
elif for_instant_style:
composition_mask = load_image(
"https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/1024_whole_mask.png"
)
female_mask = load_image(
"https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter_None_20240321125641_mask.png"
)
male_mask = load_image(
"https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter_None_20240321125344_mask.png"
)
background_mask = load_image(
"https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter_6_20240321130722_mask.png"
)
ip_composition_image = load_image(
"https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321125152.png"
)
ip_female_style = load_image(
"https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321125625.png"
)
ip_male_style = load_image(
"https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321125329.png"
)
ip_background = load_image(
"https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/ip_adapter__20240321130643.png"
)
input_kwargs.update(
{
"ip_adapter_image": [ip_composition_image, [ip_female_style, ip_male_style, ip_background]],
"cross_attention_kwargs": {
"ip_adapter_masks": [[composition_mask], [female_mask, male_mask, background_mask]]
},
}
)
return input_kwargs return input_kwargs
...@@ -575,6 +611,48 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin): ...@@ -575,6 +611,48 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice) max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
assert max_diff < 5e-4 assert max_diff < 5e-4
def test_instant_style_multiple_masks(self):
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16
).to("cuda")
pipeline = StableDiffusionXLPipeline.from_pretrained(
"RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.float16, image_encoder=image_encoder, variant="fp16"
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.load_ip_adapter(
["ostris/ip-composition-adapter", "h94/IP-Adapter"],
subfolder=["", "sdxl_models"],
weight_name=[
"ip_plus_composition_sdxl.safetensors",
"ip-adapter_sdxl_vit-h.safetensors",
],
image_encoder_folder=None,
)
scale_1 = {
"down": [[0.0, 0.0, 1.0]],
"mid": [[0.0, 0.0, 1.0]],
"up": {"block_0": [[0.0, 0.0, 1.0], [1.0, 1.0, 1.0], [0.0, 0.0, 1.0]], "block_1": [[0.0, 0.0, 1.0]]},
}
pipeline.set_ip_adapter_scale([1.0, scale_1])
inputs = self.get_dummy_inputs(for_instant_style=True)
processor = IPAdapterMaskProcessor()
masks1 = inputs["cross_attention_kwargs"]["ip_adapter_masks"][0]
masks2 = inputs["cross_attention_kwargs"]["ip_adapter_masks"][1]
masks1 = processor.preprocess(masks1, height=1024, width=1024)
masks2 = processor.preprocess(masks2, height=1024, width=1024)
masks2 = masks2.reshape(1, masks2.shape[0], masks2.shape[2], masks2.shape[3])
inputs["cross_attention_kwargs"]["ip_adapter_masks"] = [masks1, masks2]
images = pipeline(**inputs).images
image_slice = images[0, :3, :3, -1].flatten()
expected_slice = np.array(
[0.23551631, 0.20476806, 0.14099443, 0.0, 0.07675594, 0.05672678, 0.0, 0.0, 0.02099729]
)
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
assert max_diff < 5e-4
def test_ip_adapter_multiple_masks_one_adapter(self): def test_ip_adapter_multiple_masks_one_adapter(self):
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder") image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
pipeline = StableDiffusionXLPipeline.from_pretrained( pipeline = StableDiffusionXLPipeline.from_pretrained(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment