Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
5915c298
Unverified
Commit
5915c298
authored
May 01, 2024
by
YiYi Xu
Committed by
GitHub
May 01, 2024
Browse files
[ip-adapter] fix ip-adapter for StableDiffusionInstructPix2PixPipeline (#7820)
update prepare_ip_adapter_ for pix2pix
parent
21a7ff12
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
87 additions
and
9 deletions
+87
-9
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
...e_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
+87
-8
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
...usion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
+0
-1
No files found.
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
View file @
5915c298
...
@@ -172,6 +172,7 @@ class StableDiffusionInstructPix2PixPipeline(
...
@@ -172,6 +172,7 @@ class StableDiffusionInstructPix2PixPipeline(
prompt_embeds
:
Optional
[
torch
.
FloatTensor
]
=
None
,
prompt_embeds
:
Optional
[
torch
.
FloatTensor
]
=
None
,
negative_prompt_embeds
:
Optional
[
torch
.
FloatTensor
]
=
None
,
negative_prompt_embeds
:
Optional
[
torch
.
FloatTensor
]
=
None
,
ip_adapter_image
:
Optional
[
PipelineImageInput
]
=
None
,
ip_adapter_image
:
Optional
[
PipelineImageInput
]
=
None
,
ip_adapter_image_embeds
:
Optional
[
List
[
torch
.
FloatTensor
]]
=
None
,
output_type
:
Optional
[
str
]
=
"pil"
,
output_type
:
Optional
[
str
]
=
"pil"
,
return_dict
:
bool
=
True
,
return_dict
:
bool
=
True
,
callback_on_step_end
:
Optional
[
Callable
[[
int
,
int
,
Dict
],
None
]]
=
None
,
callback_on_step_end
:
Optional
[
Callable
[[
int
,
int
,
Dict
],
None
]]
=
None
,
...
@@ -296,6 +297,8 @@ class StableDiffusionInstructPix2PixPipeline(
...
@@ -296,6 +297,8 @@ class StableDiffusionInstructPix2PixPipeline(
negative_prompt
,
negative_prompt
,
prompt_embeds
,
prompt_embeds
,
negative_prompt_embeds
,
negative_prompt_embeds
,
ip_adapter_image
,
ip_adapter_image_embeds
,
callback_on_step_end_tensor_inputs
,
callback_on_step_end_tensor_inputs
,
)
)
self
.
_guidance_scale
=
guidance_scale
self
.
_guidance_scale
=
guidance_scale
...
@@ -303,14 +306,6 @@ class StableDiffusionInstructPix2PixPipeline(
...
@@ -303,14 +306,6 @@ class StableDiffusionInstructPix2PixPipeline(
device
=
self
.
_execution_device
device
=
self
.
_execution_device
if
ip_adapter_image
is
not
None
:
output_hidden_state
=
False
if
isinstance
(
self
.
unet
.
encoder_hid_proj
,
ImageProjection
)
else
True
image_embeds
,
negative_image_embeds
=
self
.
encode_image
(
ip_adapter_image
,
device
,
num_images_per_prompt
,
output_hidden_state
)
if
self
.
do_classifier_free_guidance
:
image_embeds
=
torch
.
cat
([
image_embeds
,
negative_image_embeds
,
negative_image_embeds
])
if
image
is
None
:
if
image
is
None
:
raise
ValueError
(
"`image` input cannot be undefined."
)
raise
ValueError
(
"`image` input cannot be undefined."
)
...
@@ -335,6 +330,14 @@ class StableDiffusionInstructPix2PixPipeline(
...
@@ -335,6 +330,14 @@ class StableDiffusionInstructPix2PixPipeline(
negative_prompt_embeds
=
negative_prompt_embeds
,
negative_prompt_embeds
=
negative_prompt_embeds
,
)
)
if
ip_adapter_image
is
not
None
or
ip_adapter_image_embeds
is
not
None
:
image_embeds
=
self
.
prepare_ip_adapter_image_embeds
(
ip_adapter_image
,
ip_adapter_image_embeds
,
device
,
batch_size
*
num_images_per_prompt
,
self
.
do_classifier_free_guidance
,
)
# 3. Preprocess image
# 3. Preprocess image
image
=
self
.
image_processor
.
preprocess
(
image
)
image
=
self
.
image_processor
.
preprocess
(
image
)
...
@@ -635,6 +638,65 @@ class StableDiffusionInstructPix2PixPipeline(
...
@@ -635,6 +638,65 @@ class StableDiffusionInstructPix2PixPipeline(
return
image_embeds
,
uncond_image_embeds
return
image_embeds
,
uncond_image_embeds
def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """Build the per-adapter list of IP-Adapter image embeddings.

    Exactly one of two paths runs:

    * ``ip_adapter_image_embeds`` is given: each tensor in the list is tiled
      ``num_images_per_prompt`` times along dim 0. With classifier-free
      guidance the tensor is first split into three chunks along dim 0
      (positive, negative, negative) and re-concatenated in that order after
      tiling.
    * ``ip_adapter_image_embeds`` is ``None``: each entry of
      ``ip_adapter_image`` is encoded with ``self.encode_image`` against the
      matching layer of ``self.unet.encoder_hid_proj.image_projection_layers``,
      stacked ``num_images_per_prompt`` times on a new leading dim, and, with
      classifier-free guidance, concatenated as (positive, negative, negative).

    Args:
        ip_adapter_image: A single image or a list of images, one per IP Adapter.
            A non-list value is wrapped in a one-element list.
        ip_adapter_image_embeds: Optional list of precomputed embedding tensors.
            # assumes one tensor per IP Adapter - TODO confirm against caller
        device: Device passed to ``self.encode_image`` and used for ``.to(device)``.
        num_images_per_prompt: Tiling factor applied along the leading dim.
        do_classifier_free_guidance: Whether to emit the three-way
            (positive, negative, negative) layout.

    Returns:
        A list of tensors, one per input image / input embedding.

    Raises:
        ValueError: If the number of images differs from the number of image
            projection layers.
    """

    def _tile(tensor):
        # Repeat along dim 0 only; every trailing dim gets a repeat factor of 1.
        trailing_ones = [1] * len(tensor.shape[1:])
        return tensor.repeat(num_images_per_prompt, *trailing_ones)

    prepared = []

    # Path 1: caller supplied precomputed embeddings.
    if ip_adapter_image_embeds is not None:
        for embeds in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                # The middle chunk is discarded; the last chunk is used as the
                # negative embedding twice, mirroring the (pos, neg, neg) layout.
                positive, _, negative = embeds.chunk(3)
                positive = _tile(positive)
                negative = _tile(negative)
                embeds = torch.cat([positive, negative, negative])
            else:
                embeds = _tile(embeds)
            prepared.append(embeds)
        return prepared

    # Path 2: encode raw images, one per IP Adapter projection layer.
    images = ip_adapter_image if isinstance(ip_adapter_image, list) else [ip_adapter_image]
    projection_layers = self.unet.encoder_hid_proj.image_projection_layers

    if len(images) != len(projection_layers):
        raise ValueError(
            f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(images)} images and {len(projection_layers)} IP Adapters."
        )

    for image, projection_layer in zip(images, projection_layers):
        # Plain ImageProjection layers take pooled embeds; any other layer type
        # is given hidden states instead.
        wants_hidden_states = not isinstance(projection_layer, ImageProjection)
        positive, negative = self.encode_image(image, device, 1, wants_hidden_states)

        positive = torch.stack([positive] * num_images_per_prompt, dim=0)
        negative = torch.stack([negative] * num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            # NOTE(review): nesting was reconstructed from a whitespace-stripped
            # listing; the device move is assumed to belong to this branch - confirm.
            positive = torch.cat([positive, negative, negative])
            positive = positive.to(device)

        prepared.append(positive)

    return prepared
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def
run_safety_checker
(
self
,
image
,
device
,
dtype
):
def
run_safety_checker
(
self
,
image
,
device
,
dtype
):
if
self
.
safety_checker
is
None
:
if
self
.
safety_checker
is
None
:
...
@@ -687,6 +749,8 @@ class StableDiffusionInstructPix2PixPipeline(
...
@@ -687,6 +749,8 @@ class StableDiffusionInstructPix2PixPipeline(
negative_prompt
=
None
,
negative_prompt
=
None
,
prompt_embeds
=
None
,
prompt_embeds
=
None
,
negative_prompt_embeds
=
None
,
negative_prompt_embeds
=
None
,
ip_adapter_image
=
None
,
ip_adapter_image_embeds
=
None
,
callback_on_step_end_tensor_inputs
=
None
,
callback_on_step_end_tensor_inputs
=
None
,
):
):
if
callback_steps
is
not
None
and
(
not
isinstance
(
callback_steps
,
int
)
or
callback_steps
<=
0
):
if
callback_steps
is
not
None
and
(
not
isinstance
(
callback_steps
,
int
)
or
callback_steps
<=
0
):
...
@@ -728,6 +792,21 @@ class StableDiffusionInstructPix2PixPipeline(
...
@@ -728,6 +792,21 @@ class StableDiffusionInstructPix2PixPipeline(
f
"
{
negative_prompt_embeds
.
shape
}
."
f
"
{
negative_prompt_embeds
.
shape
}
."
)
)
if
ip_adapter_image
is
not
None
and
ip_adapter_image_embeds
is
not
None
:
raise
ValueError
(
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
)
if
ip_adapter_image_embeds
is
not
None
:
if
not
isinstance
(
ip_adapter_image_embeds
,
list
):
raise
ValueError
(
f
"`ip_adapter_image_embeds` has to be of type `list` but is
{
type
(
ip_adapter_image_embeds
)
}
"
)
elif
ip_adapter_image_embeds
[
0
].
ndim
not
in
[
3
,
4
]:
raise
ValueError
(
f
"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is
{
ip_adapter_image_embeds
[
0
].
ndim
}
D"
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def
prepare_latents
(
self
,
batch_size
,
num_channels_latents
,
height
,
width
,
dtype
,
device
,
generator
,
latents
=
None
):
def
prepare_latents
(
self
,
batch_size
,
num_channels_latents
,
height
,
width
,
dtype
,
device
,
generator
,
latents
=
None
):
shape
=
(
shape
=
(
...
...
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
View file @
5915c298
...
@@ -436,7 +436,6 @@ class StableDiffusionXLInstructPix2PixPipeline(
...
@@ -436,7 +436,6 @@ class StableDiffusionXLInstructPix2PixPipeline(
extra_step_kwargs
[
"generator"
]
=
generator
extra_step_kwargs
[
"generator"
]
=
generator
return
extra_step_kwargs
return
extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.check_inputs
def
check_inputs
(
def
check_inputs
(
self
,
self
,
prompt
,
prompt
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment