renzhc / diffusers_dcu · Commits · 5915c298

Unverified commit 5915c298, authored May 01, 2024 by YiYi Xu, committed by GitHub on May 01, 2024
[ip-adapter] fix ip-adapter for StableDiffusionInstructPix2PixPipeline (#7820)
update prepare_ip_adapter_ for pix2pix
parent 21a7ff12
Showing 2 changed files with 87 additions and 9 deletions (+87 −9):

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py (+87 −8)
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py (+0 −1)
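With this change the InstructPix2Pix pipeline accepts `ip_adapter_image` (and precomputed `ip_adapter_image_embeds`) like the other Stable Diffusion pipelines. A minimal usage sketch follows; it is not part of the commit, and the checkpoint names, URLs, and parameter values are illustrative assumptions:

```python
# Sketch only (not from this commit): run InstructPix2Pix with an IP-Adapter after the fix.
# Model IDs, image URLs, and parameter values below are illustrative assumptions.
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from diffusers.utils import load_image

pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
).to("cuda")

# Load an SD 1.5 IP-Adapter; its image embeddings go through the UNet's encoder_hid_proj,
# which the new prepare_ip_adapter_image_embeds method relies on.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipe.set_ip_adapter_scale(0.6)

image = load_image("https://example.com/input.png")        # image to edit (illustrative URL)
style_image = load_image("https://example.com/style.png")  # IP-Adapter reference (illustrative URL)

result = pipe(
    prompt="turn the sky into a sunset",
    image=image,
    ip_adapter_image=style_image,  # now routed through prepare_ip_adapter_image_embeds
    num_inference_steps=30,
    image_guidance_scale=1.5,
).images[0]
result.save("edited.png")
```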
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -172,6 +172,7 @@ class StableDiffusionInstructPix2PixPipeline(
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
@@ -296,6 +297,8 @@ class StableDiffusionInstructPix2PixPipeline(
             negative_prompt,
             prompt_embeds,
             negative_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
             callback_on_step_end_tensor_inputs,
         )
         self._guidance_scale = guidance_scale
@@ -303,14 +306,6 @@ class StableDiffusionInstructPix2PixPipeline(
         device = self._execution_device

-        if ip_adapter_image is not None:
-            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
-            image_embeds, negative_image_embeds = self.encode_image(
-                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
-            )
-            if self.do_classifier_free_guidance:
-                image_embeds = torch.cat([image_embeds, negative_image_embeds, negative_image_embeds])
-
         if image is None:
             raise ValueError("`image` input cannot be undefined.")
@@ -335,6 +330,14 @@ class StableDiffusionInstructPix2PixPipeline(
             negative_prompt_embeds=negative_prompt_embeds,
         )

+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )

         # 3. Preprocess image
         image = self.image_processor.preprocess(image)
@@ -635,6 +638,65 @@ class StableDiffusionInstructPix2PixPipeline(
         return image_embeds, uncond_image_embeds

+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+    ):
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )
+
+                if do_classifier_free_guidance:
+                    single_image_embeds = torch.cat(
+                        [single_image_embeds, single_negative_image_embeds, single_negative_image_embeds]
+                    )
+                    single_image_embeds = single_image_embeds.to(device)
+
+                image_embeds.append(single_image_embeds)
+        else:
+            repeat_dims = [1]
+            image_embeds = []
+            for single_image_embeds in ip_adapter_image_embeds:
+                if do_classifier_free_guidance:
+                    (
+                        single_image_embeds,
+                        single_negative_image_embeds,
+                        single_negative_image_embeds,
+                    ) = single_image_embeds.chunk(3)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                    single_negative_image_embeds = single_negative_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+                    )
+                    single_image_embeds = torch.cat(
+                        [single_image_embeds, single_negative_image_embeds, single_negative_image_embeds]
+                    )
+                else:
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                image_embeds.append(single_image_embeds)
+
+        return image_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
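Note on the method added above: InstructPix2Pix runs classifier-free guidance over three noise predictions (text and image conditioned, image-only conditioned, fully unconditional), which is why this helper concatenates the IP-Adapter embeddings as [cond, uncond, uncond] and splits precomputed embeddings with .chunk(3) instead of the usual two-way split. A rough shape walkthrough, with hypothetical tensor sizes that are not taken from the commit:

```python
# Illustrative shapes only: how the three-way CFG batch for the IP-Adapter
# embeddings is laid out. Dimensions below are hypothetical stand-ins.
import torch

num_images_per_prompt = 2
cond = torch.randn(1, 4, 768)    # stand-in for a conditional image embedding
uncond = torch.zeros(1, 4, 768)  # stand-in for the unconditional embedding

# Repeat per generated image, then order as [cond, uncond, uncond] so each chunk
# lines up with one of the pipeline's three latent branches.
cond = torch.stack([cond] * num_images_per_prompt, dim=0)
uncond = torch.stack([uncond] * num_images_per_prompt, dim=0)
cfg_batch = torch.cat([cond, uncond, uncond])
print(cfg_batch.shape)  # torch.Size([6, 1, 4, 768]) -> 3 * num_images_per_prompt entries

# Precomputed embeddings passed back into the pipeline are split the same way:
c, u1, u2 = cfg_batch.chunk(3)
assert torch.equal(c, cond) and torch.equal(u1, uncond)
```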
@@ -687,6 +749,8 @@ class StableDiffusionInstructPix2PixPipeline(
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
         callback_on_step_end_tensor_inputs=None,
     ):
         if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
@@ -728,6 +792,21 @@ class StableDiffusionInstructPix2PixPipeline(
                 f" {negative_prompt_embeds.shape}."
             )

+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+            )
+
+        if ip_adapter_image_embeds is not None:
+            if not isinstance(ip_adapter_image_embeds, list):
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+                )
+            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                )
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
         shape = (
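The new check_inputs branches also allow `ip_adapter_image_embeds` to be precomputed once and reused across calls: the list returned by prepare_ip_adapter_image_embeds is already in the three-way CFG layout, so the reuse path only re-chunks and repeats it. A hedged sketch, continuing the hypothetical `pipe`, `image`, and `style_image` objects from the earlier example:

```python
# Sketch only (not from this commit): precompute IP-Adapter embeddings once and reuse
# them via the new ip_adapter_image_embeds argument. `pipe`, `image`, and `style_image`
# are the hypothetical objects from the earlier usage example.
image_embeds = pipe.prepare_ip_adapter_image_embeds(
    ip_adapter_image=style_image,
    ip_adapter_image_embeds=None,
    device="cuda",
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)

edits = ["make it snowy", "make it autumn"]
results = [
    pipe(
        prompt=edit,
        image=image,
        ip_adapter_image_embeds=image_embeds,  # list of 3D/4D tensors, as check_inputs now enforces
        num_inference_steps=30,
    ).images[0]
    for edit in edits
]
```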
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py

@@ -436,7 +436,6 @@ class StableDiffusionXLInstructPix2PixPipeline(
         extra_step_kwargs["generator"] = generator
         return extra_step_kwargs

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.check_inputs
     def check_inputs(
         self,
         prompt,