Unverified commit 13001ee3, authored by Fabio Rigano and committed by GitHub
Browse files

Bugfix in IPAdapterFaceID (#6835)

parent 65329aed
...@@ -104,6 +104,22 @@ class LoRAIPAdapterAttnProcessor(nn.Module): ...@@ -104,6 +104,22 @@ class LoRAIPAdapterAttnProcessor(nn.Module):
): ):
residual = hidden_states residual = hidden_states
# separate ip_hidden_states from encoder_hidden_states
if encoder_hidden_states is not None:
if isinstance(encoder_hidden_states, tuple):
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
else:
deprecation_message = (
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
)
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
[encoder_hidden_states[:, end_pos:, :]],
)
if attn.spatial_norm is not None: if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb) hidden_states = attn.spatial_norm(hidden_states, temb)
...@@ -125,14 +141,7 @@ class LoRAIPAdapterAttnProcessor(nn.Module): ...@@ -125,14 +141,7 @@ class LoRAIPAdapterAttnProcessor(nn.Module):
if encoder_hidden_states is None: if encoder_hidden_states is None:
encoder_hidden_states = hidden_states encoder_hidden_states = hidden_states
else: elif attn.norm_cross:
# get encoder_hidden_states, ip_hidden_states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
encoder_hidden_states[:, end_pos:, :],
)
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
...@@ -233,6 +242,22 @@ class LoRAIPAdapterAttnProcessor2_0(nn.Module): ...@@ -233,6 +242,22 @@ class LoRAIPAdapterAttnProcessor2_0(nn.Module):
): ):
residual = hidden_states residual = hidden_states
# separate ip_hidden_states from encoder_hidden_states
if encoder_hidden_states is not None:
if isinstance(encoder_hidden_states, tuple):
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
else:
deprecation_message = (
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
)
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
[encoder_hidden_states[:, end_pos:, :]],
)
if attn.spatial_norm is not None: if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb) hidden_states = attn.spatial_norm(hidden_states, temb)
...@@ -259,14 +284,7 @@ class LoRAIPAdapterAttnProcessor2_0(nn.Module): ...@@ -259,14 +284,7 @@ class LoRAIPAdapterAttnProcessor2_0(nn.Module):
if encoder_hidden_states is None: if encoder_hidden_states is None:
encoder_hidden_states = hidden_states encoder_hidden_states = hidden_states
else: elif attn.norm_cross:
# get encoder_hidden_states, ip_hidden_states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
encoder_hidden_states[:, end_pos:, :],
)
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
...@@ -951,30 +969,6 @@ class IPAdapterFaceIDStableDiffusionPipeline( ...@@ -951,30 +969,6 @@ class IPAdapterFaceIDStableDiffusionPipeline(
return prompt_embeds, negative_prompt_embeds return prompt_embeds, negative_prompt_embeds
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """Encode an image with ``self.image_encoder`` and build a matching unconditional input.

    Args:
        image: Either a ``torch.Tensor`` (used as-is) or any other input, which is
            first run through ``self.feature_extractor(..., return_tensors="pt")``
            and replaced by the resulting ``pixel_values``.
        device: Device the image tensor is moved to before encoding.
        num_images_per_prompt: Number of times each entry is repeated along
            dim 0 (via ``repeat_interleave``).
        output_hidden_states: When truthy, return the second-to-last entry of the
            encoder's ``hidden_states`` instead of its ``image_embeds``.

    Returns:
        A ``(conditional, unconditional)`` tuple of tensors.

        * ``output_hidden_states`` truthy: the unconditional tensor is obtained by
          encoding an all-zero image of the same shape as ``image``.
        * otherwise: the unconditional tensor is ``torch.zeros_like`` of the
          conditional embeddings (the encoder is not run a second time).
    """
    # Match the encoder's parameter dtype so the forward pass does not mix dtypes.
    dtype = next(self.image_encoder.parameters()).dtype

    # Non-tensor inputs are converted to pixel values by the feature extractor.
    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values

    image = image.to(device=device, dtype=dtype)

    if output_hidden_states:
        # Index of the hidden-state entry that is returned (second-to-last).
        layer_index = -2

        image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[layer_index]
        image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)

        # The unconditional branch encodes a zero image rather than zeroing the output.
        uncond_image_enc_hidden_states = self.image_encoder(
            torch.zeros_like(image), output_hidden_states=True
        ).hidden_states[layer_index]
        uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
            num_images_per_prompt, dim=0
        )
        return image_enc_hidden_states, uncond_image_enc_hidden_states

    image_embeds = self.image_encoder(image).image_embeds
    image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
    # Here the unconditional embeddings are plain zeros of the same shape/dtype/device.
    uncond_image_embeds = torch.zeros_like(image_embeds)
    return image_embeds, uncond_image_embeds
def run_safety_checker(self, image, device, dtype): def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None: if self.safety_checker is None:
has_nsfw_concept = None has_nsfw_concept = None
...@@ -1302,7 +1296,6 @@ class IPAdapterFaceIDStableDiffusionPipeline( ...@@ -1302,7 +1296,6 @@ class IPAdapterFaceIDStableDiffusionPipeline(
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
image_embeds (`torch.FloatTensor`, *optional*): image_embeds (`torch.FloatTensor`, *optional*):
Pre-generated image embeddings. Pre-generated image embeddings.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`): output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`. The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
...@@ -1411,7 +1404,7 @@ class IPAdapterFaceIDStableDiffusionPipeline( ...@@ -1411,7 +1404,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if image_embeds is not None: if image_embeds is not None:
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0).to( image_embeds = torch.stack([image_embeds] * num_images_per_prompt, dim=0).to(
device=device, dtype=prompt_embeds.dtype device=device, dtype=prompt_embeds.dtype
) )
negative_image_embeds = torch.zeros_like(image_embeds) negative_image_embeds = torch.zeros_like(image_embeds)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment