Unverified Commit a0cf6076 authored by Fabio Rigano's avatar Fabio Rigano Committed by GitHub
Browse files

Multi-image masking for single IP Adapter (#7499)



* Support multiimage masking

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent a341b536
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
from importlib import import_module from importlib import import_module
from typing import Callable, Optional, Union from typing import Callable, List, Optional, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -2195,15 +2195,33 @@ class IPAdapterAttnProcessor(nn.Module): ...@@ -2195,15 +2195,33 @@ class IPAdapterAttnProcessor(nn.Module):
hidden_states = attn.batch_to_head_dim(hidden_states) hidden_states = attn.batch_to_head_dim(hidden_states)
if ip_adapter_masks is not None: if ip_adapter_masks is not None:
if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4: if not isinstance(ip_adapter_masks, List):
# for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
raise ValueError( raise ValueError(
" ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]." f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
" Please use `IPAdapterMaskProcessor` to preprocess your mask" f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
) f"({len(ip_hidden_states)})"
if len(ip_adapter_masks) != len(self.scale):
raise ValueError(
f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
) )
else:
for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
raise ValueError(
"Each element of the ip_adapter_masks array should be a tensor with shape "
"[1, num_images_for_ip_adapter, height, width]."
" Please use `IPAdapterMaskProcessor` to preprocess your mask"
)
if mask.shape[1] != ip_state.shape[1]:
raise ValueError(
f"Number of masks ({mask.shape[1]}) does not match "
f"number of ip images ({ip_state.shape[1]}) at index {index}"
)
if isinstance(scale, list) and not len(scale) == mask.shape[1]:
raise ValueError(
f"Number of masks ({mask.shape[1]}) does not match "
f"number of scales ({len(scale)}) at index {index}"
)
else: else:
ip_adapter_masks = [None] * len(self.scale) ip_adapter_masks = [None] * len(self.scale)
...@@ -2211,26 +2229,44 @@ class IPAdapterAttnProcessor(nn.Module): ...@@ -2211,26 +2229,44 @@ class IPAdapterAttnProcessor(nn.Module):
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
): ):
ip_key = to_k_ip(current_ip_hidden_states) if mask is not None:
ip_value = to_v_ip(current_ip_hidden_states) if not isinstance(scale, list):
scale = [scale]
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value) current_num_images = mask.shape[1]
for i in range(current_num_images):
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
_current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
_current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
mask_downsample = IPAdapterMaskProcessor.downsample(
mask[:, i, :, :],
batch_size,
_current_ip_hidden_states.shape[1],
_current_ip_hidden_states.shape[2],
)
ip_attention_probs = attn.get_attention_scores(query, ip_key, None) mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
if mask is not None: hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
mask_downsample = IPAdapterMaskProcessor.downsample( else:
mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2] ip_key = to_k_ip(current_ip_hidden_states)
) ip_value = to_v_ip(current_ip_hidden_states)
mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
current_ip_hidden_states = current_ip_hidden_states * mask_downsample ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
hidden_states = hidden_states + scale * current_ip_hidden_states hidden_states = hidden_states + scale * current_ip_hidden_states
# linear proj # linear proj
hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[0](hidden_states)
...@@ -2369,15 +2405,33 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module): ...@@ -2369,15 +2405,33 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
hidden_states = hidden_states.to(query.dtype) hidden_states = hidden_states.to(query.dtype)
if ip_adapter_masks is not None: if ip_adapter_masks is not None:
if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4: if not isinstance(ip_adapter_masks, List):
raise ValueError( # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
" ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]." ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
" Please use `IPAdapterMaskProcessor` to preprocess your mask" if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
)
if len(ip_adapter_masks) != len(self.scale):
raise ValueError( raise ValueError(
f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})" f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
f"({len(ip_hidden_states)})"
) )
else:
for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
raise ValueError(
"Each element of the ip_adapter_masks array should be a tensor with shape "
"[1, num_images_for_ip_adapter, height, width]."
" Please use `IPAdapterMaskProcessor` to preprocess your mask"
)
if mask.shape[1] != ip_state.shape[1]:
raise ValueError(
f"Number of masks ({mask.shape[1]}) does not match "
f"number of ip images ({ip_state.shape[1]}) at index {index}"
)
if isinstance(scale, list) and not len(scale) == mask.shape[1]:
raise ValueError(
f"Number of masks ({mask.shape[1]}) does not match "
f"number of scales ({len(scale)}) at index {index}"
)
else: else:
ip_adapter_masks = [None] * len(self.scale) ip_adapter_masks = [None] * len(self.scale)
...@@ -2385,33 +2439,57 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module): ...@@ -2385,33 +2439,57 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip( for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
): ):
ip_key = to_k_ip(current_ip_hidden_states) if mask is not None:
ip_value = to_v_ip(current_ip_hidden_states) if not isinstance(scale, list):
scale = [scale]
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) current_num_images = mask.shape[1]
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) for i in range(current_num_images):
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
# the output of sdp = (batch, num_heads, seq_len, head_dim) ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# TODO: add support for attn.scale when we move to Torch 2.1 ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
current_ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( # the output of sdp = (batch, num_heads, seq_len, head_dim)
batch_size, -1, attn.heads * head_dim # TODO: add support for attn.scale when we move to Torch 2.1
) _current_ip_hidden_states = F.scaled_dot_product_attention(
current_ip_hidden_states = current_ip_hidden_states.to(query.dtype) query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
if mask is not None: _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
mask_downsample = IPAdapterMaskProcessor.downsample( batch_size, -1, attn.heads * head_dim
mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2] )
) _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device) mask_downsample = IPAdapterMaskProcessor.downsample(
mask[:, i, :, :],
batch_size,
_current_ip_hidden_states.shape[1],
_current_ip_hidden_states.shape[2],
)
current_ip_hidden_states = current_ip_hidden_states * mask_downsample mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
else:
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
current_ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
hidden_states = hidden_states + scale * current_ip_hidden_states hidden_states = hidden_states + scale * current_ip_hidden_states
# linear proj # linear proj
hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[0](hidden_states)
......
...@@ -544,3 +544,33 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin): ...@@ -544,3 +544,33 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice) max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
assert max_diff < 5e-4 assert max_diff < 5e-4
def test_ip_adapter_multiple_masks_one_adapter(self):
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
image_encoder=image_encoder,
torch_dtype=self.dtype,
)
pipeline.enable_model_cpu_offload()
pipeline.load_ip_adapter(
"h94/IP-Adapter", subfolder="sdxl_models", weight_name=["ip-adapter-plus-face_sdxl_vit-h.safetensors"]
)
pipeline.set_ip_adapter_scale([[0.7, 0.7]])
inputs = self.get_dummy_inputs(for_masks=True)
masks = inputs["cross_attention_kwargs"]["ip_adapter_masks"]
processor = IPAdapterMaskProcessor()
masks = processor.preprocess(masks)
masks = masks.reshape(1, masks.shape[0], masks.shape[2], masks.shape[3])
inputs["cross_attention_kwargs"]["ip_adapter_masks"] = [masks]
ip_images = inputs["ip_adapter_image"]
inputs["ip_adapter_image"] = [[image[0] for image in ip_images]]
images = pipeline(**inputs).images
image_slice = images[0, :3, :3, -1].flatten()
expected_slice = np.array(
[0.79474676, 0.7977683, 0.8013954, 0.7988008, 0.7970615, 0.8029355, 0.80614823, 0.8050743, 0.80627424]
)
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
assert max_diff < 5e-4
...@@ -238,6 +238,11 @@ class IPAdapterTesterMixin: ...@@ -238,6 +238,11 @@ class IPAdapterTesterMixin:
def _get_dummy_image_embeds(self, cross_attention_dim: int = 32): def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
return torch.randn((2, 1, cross_attention_dim), device=torch_device) return torch.randn((2, 1, cross_attention_dim), device=torch_device)
def _get_dummy_masks(self, input_size: int = 64):
_masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
_masks[0, :, :, : int(input_size / 2)] = 1
return _masks
def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]): def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
parameters = inspect.signature(self.pipeline_class.__call__).parameters parameters = inspect.signature(self.pipeline_class.__call__).parameters
if "image" in parameters.keys() and "strength" in parameters.keys(): if "image" in parameters.keys() and "strength" in parameters.keys():
...@@ -365,6 +370,51 @@ class IPAdapterTesterMixin: ...@@ -365,6 +370,51 @@ class IPAdapterTesterMixin:
assert out_cfg.shape == out_no_cfg.shape assert out_cfg.shape == out_no_cfg.shape
def test_ip_adapter_masks(self, expected_max_diff: float = 1e-4):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(torch_device)
pipe.set_progress_bar_config(disable=None)
cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
sample_size = pipe.unet.config.get("sample_size", 32)
block_out_channels = pipe.vae.config.get("block_out_channels", [128, 256, 512, 512])
input_size = sample_size * (2 ** (len(block_out_channels) - 1))
# forward pass without ip adapter
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
output_without_adapter = pipe(**inputs)[0]
output_without_adapter = output_without_adapter[0, -3:, -3:, -1].flatten()
adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
pipe.unet._load_ip_adapter_weights(adapter_state_dict)
# forward pass with single ip adapter and masks, but scale=0 which should have no effect
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
inputs["cross_attention_kwargs"] = {"ip_adapter_masks": [self._get_dummy_masks(input_size)]}
pipe.set_ip_adapter_scale(0.0)
output_without_adapter_scale = pipe(**inputs)[0]
output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
# forward pass with single ip adapter and masks, but with scale of adapter weights
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
inputs["cross_attention_kwargs"] = {"ip_adapter_masks": [self._get_dummy_masks(input_size)]}
pipe.set_ip_adapter_scale(42.0)
output_with_adapter_scale = pipe(**inputs)[0]
output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
self.assertLess(
max_diff_without_adapter_scale,
expected_max_diff,
"Output without ip-adapter must be same as normal inference",
)
self.assertGreater(
max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference"
)
class PipelineLatentTesterMixin: class PipelineLatentTesterMixin:
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment