Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
13001ee3
Unverified
Commit
13001ee3
authored
Feb 03, 2024
by
Fabio Rigano
Committed by
GitHub
Feb 03, 2024
Browse files
Bugfix in IPAdapterFaceID (#6835)
parent
65329aed
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
44 deletions
+37
-44
examples/community/ip_adapter_face_id.py
examples/community/ip_adapter_face_id.py
+37
-44
No files found.
examples/community/ip_adapter_face_id.py
View file @
13001ee3
...
@@ -104,6 +104,22 @@ class LoRAIPAdapterAttnProcessor(nn.Module):
...
@@ -104,6 +104,22 @@ class LoRAIPAdapterAttnProcessor(nn.Module):
):
):
residual
=
hidden_states
residual
=
hidden_states
# separate ip_hidden_states from encoder_hidden_states
if
encoder_hidden_states
is
not
None
:
if
isinstance
(
encoder_hidden_states
,
tuple
):
encoder_hidden_states
,
ip_hidden_states
=
encoder_hidden_states
else
:
deprecation_message
=
(
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
)
deprecate
(
"encoder_hidden_states not a tuple"
,
"1.0.0"
,
deprecation_message
,
standard_warn
=
False
)
end_pos
=
encoder_hidden_states
.
shape
[
1
]
-
self
.
num_tokens
[
0
]
encoder_hidden_states
,
ip_hidden_states
=
(
encoder_hidden_states
[:,
:
end_pos
,
:],
[
encoder_hidden_states
[:,
end_pos
:,
:]],
)
if
attn
.
spatial_norm
is
not
None
:
if
attn
.
spatial_norm
is
not
None
:
hidden_states
=
attn
.
spatial_norm
(
hidden_states
,
temb
)
hidden_states
=
attn
.
spatial_norm
(
hidden_states
,
temb
)
...
@@ -125,15 +141,8 @@ class LoRAIPAdapterAttnProcessor(nn.Module):
...
@@ -125,15 +141,8 @@ class LoRAIPAdapterAttnProcessor(nn.Module):
if
encoder_hidden_states
is
None
:
if
encoder_hidden_states
is
None
:
encoder_hidden_states
=
hidden_states
encoder_hidden_states
=
hidden_states
else
:
elif
attn
.
norm_cross
:
# get encoder_hidden_states, ip_hidden_states
encoder_hidden_states
=
attn
.
norm_encoder_hidden_states
(
encoder_hidden_states
)
end_pos
=
encoder_hidden_states
.
shape
[
1
]
-
self
.
num_tokens
encoder_hidden_states
,
ip_hidden_states
=
(
encoder_hidden_states
[:,
:
end_pos
,
:],
encoder_hidden_states
[:,
end_pos
:,
:],
)
if
attn
.
norm_cross
:
encoder_hidden_states
=
attn
.
norm_encoder_hidden_states
(
encoder_hidden_states
)
key
=
attn
.
to_k
(
encoder_hidden_states
)
+
self
.
lora_scale
*
self
.
to_k_lora
(
encoder_hidden_states
)
key
=
attn
.
to_k
(
encoder_hidden_states
)
+
self
.
lora_scale
*
self
.
to_k_lora
(
encoder_hidden_states
)
value
=
attn
.
to_v
(
encoder_hidden_states
)
+
self
.
lora_scale
*
self
.
to_v_lora
(
encoder_hidden_states
)
value
=
attn
.
to_v
(
encoder_hidden_states
)
+
self
.
lora_scale
*
self
.
to_v_lora
(
encoder_hidden_states
)
...
@@ -233,6 +242,22 @@ class LoRAIPAdapterAttnProcessor2_0(nn.Module):
...
@@ -233,6 +242,22 @@ class LoRAIPAdapterAttnProcessor2_0(nn.Module):
):
):
residual
=
hidden_states
residual
=
hidden_states
# separate ip_hidden_states from encoder_hidden_states
if
encoder_hidden_states
is
not
None
:
if
isinstance
(
encoder_hidden_states
,
tuple
):
encoder_hidden_states
,
ip_hidden_states
=
encoder_hidden_states
else
:
deprecation_message
=
(
"You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release."
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning."
)
deprecate
(
"encoder_hidden_states not a tuple"
,
"1.0.0"
,
deprecation_message
,
standard_warn
=
False
)
end_pos
=
encoder_hidden_states
.
shape
[
1
]
-
self
.
num_tokens
[
0
]
encoder_hidden_states
,
ip_hidden_states
=
(
encoder_hidden_states
[:,
:
end_pos
,
:],
[
encoder_hidden_states
[:,
end_pos
:,
:]],
)
if
attn
.
spatial_norm
is
not
None
:
if
attn
.
spatial_norm
is
not
None
:
hidden_states
=
attn
.
spatial_norm
(
hidden_states
,
temb
)
hidden_states
=
attn
.
spatial_norm
(
hidden_states
,
temb
)
...
@@ -259,15 +284,8 @@ class LoRAIPAdapterAttnProcessor2_0(nn.Module):
...
@@ -259,15 +284,8 @@ class LoRAIPAdapterAttnProcessor2_0(nn.Module):
if
encoder_hidden_states
is
None
:
if
encoder_hidden_states
is
None
:
encoder_hidden_states
=
hidden_states
encoder_hidden_states
=
hidden_states
else
:
elif
attn
.
norm_cross
:
# get encoder_hidden_states, ip_hidden_states
encoder_hidden_states
=
attn
.
norm_encoder_hidden_states
(
encoder_hidden_states
)
end_pos
=
encoder_hidden_states
.
shape
[
1
]
-
self
.
num_tokens
encoder_hidden_states
,
ip_hidden_states
=
(
encoder_hidden_states
[:,
:
end_pos
,
:],
encoder_hidden_states
[:,
end_pos
:,
:],
)
if
attn
.
norm_cross
:
encoder_hidden_states
=
attn
.
norm_encoder_hidden_states
(
encoder_hidden_states
)
key
=
attn
.
to_k
(
encoder_hidden_states
)
+
self
.
lora_scale
*
self
.
to_k_lora
(
encoder_hidden_states
)
key
=
attn
.
to_k
(
encoder_hidden_states
)
+
self
.
lora_scale
*
self
.
to_k_lora
(
encoder_hidden_states
)
value
=
attn
.
to_v
(
encoder_hidden_states
)
+
self
.
lora_scale
*
self
.
to_v_lora
(
encoder_hidden_states
)
value
=
attn
.
to_v
(
encoder_hidden_states
)
+
self
.
lora_scale
*
self
.
to_v_lora
(
encoder_hidden_states
)
...
@@ -951,30 +969,6 @@ class IPAdapterFaceIDStableDiffusionPipeline(
...
@@ -951,30 +969,6 @@ class IPAdapterFaceIDStableDiffusionPipeline(
return
prompt_embeds
,
negative_prompt_embeds
return
prompt_embeds
,
negative_prompt_embeds
def
encode_image
(
self
,
image
,
device
,
num_images_per_prompt
,
output_hidden_states
=
None
):
dtype
=
next
(
self
.
image_encoder
.
parameters
()).
dtype
if
not
isinstance
(
image
,
torch
.
Tensor
):
image
=
self
.
feature_extractor
(
image
,
return_tensors
=
"pt"
).
pixel_values
image
=
image
.
to
(
device
=
device
,
dtype
=
dtype
)
if
output_hidden_states
:
image_enc_hidden_states
=
self
.
image_encoder
(
image
,
output_hidden_states
=
True
).
hidden_states
[
-
2
]
image_enc_hidden_states
=
image_enc_hidden_states
.
repeat_interleave
(
num_images_per_prompt
,
dim
=
0
)
uncond_image_enc_hidden_states
=
self
.
image_encoder
(
torch
.
zeros_like
(
image
),
output_hidden_states
=
True
).
hidden_states
[
-
2
]
uncond_image_enc_hidden_states
=
uncond_image_enc_hidden_states
.
repeat_interleave
(
num_images_per_prompt
,
dim
=
0
)
return
image_enc_hidden_states
,
uncond_image_enc_hidden_states
else
:
image_embeds
=
self
.
image_encoder
(
image
).
image_embeds
image_embeds
=
image_embeds
.
repeat_interleave
(
num_images_per_prompt
,
dim
=
0
)
uncond_image_embeds
=
torch
.
zeros_like
(
image_embeds
)
return
image_embeds
,
uncond_image_embeds
def
run_safety_checker
(
self
,
image
,
device
,
dtype
):
def
run_safety_checker
(
self
,
image
,
device
,
dtype
):
if
self
.
safety_checker
is
None
:
if
self
.
safety_checker
is
None
:
has_nsfw_concept
=
None
has_nsfw_concept
=
None
...
@@ -1302,7 +1296,6 @@ class IPAdapterFaceIDStableDiffusionPipeline(
...
@@ -1302,7 +1296,6 @@ class IPAdapterFaceIDStableDiffusionPipeline(
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
image_embeds (`torch.FloatTensor`, *optional*):
image_embeds (`torch.FloatTensor`, *optional*):
Pre-generated image embeddings.
Pre-generated image embeddings.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
return_dict (`bool`, *optional*, defaults to `True`):
...
@@ -1411,7 +1404,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
...
@@ -1411,7 +1404,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
prompt_embeds
=
torch
.
cat
([
negative_prompt_embeds
,
prompt_embeds
])
prompt_embeds
=
torch
.
cat
([
negative_prompt_embeds
,
prompt_embeds
])
if
image_embeds
is
not
None
:
if
image_embeds
is
not
None
:
image_embeds
=
image_embeds
.
repeat_interleave
(
num_images_per_prompt
,
dim
=
0
).
to
(
image_embeds
=
torch
.
stack
([
image_embeds
]
*
num_images_per_prompt
,
dim
=
0
).
to
(
device
=
device
,
dtype
=
prompt_embeds
.
dtype
device
=
device
,
dtype
=
prompt_embeds
.
dtype
)
)
negative_image_embeds
=
torch
.
zeros_like
(
image_embeds
)
negative_image_embeds
=
torch
.
zeros_like
(
image_embeds
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment