renzhc / diffusers_dcu / Commits / e8b65bff

Unverified commit e8b65bff, authored Dec 12, 2024 by hlky, committed by GitHub on Dec 12, 2024.
refactor StableDiffusionXLControlNetUnion (#10200)
Parent: f2d348d9

Showing 5 changed files with 185 additions and 310 deletions (+185 −310).
Files changed:

  src/diffusers/models/controlnets/__init__.py                                    +1   −1
  src/diffusers/models/controlnets/controlnet_union.py                            +8   −93
  src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py   +60  −68
  src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py           +56  −77
  src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py   +60  −71
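In short, the refactor removes the `ControlNetUnionInput`/`ControlNetUnionInputProMax` wrapper dataclasses and switches the three union pipelines to a plain `control_image` list paired with `control_mode` indices. A minimal before/after sketch of the call-site change (the pipeline and condition image are illustrative, not from this diff):

```py
# Before #10200: conditions were wrapped in a dataclass keyed by task name.
# union_input = ControlNetUnionInput(canny=canny_img)
# image = pipe(prompt, image=union_input).images[0]

# After #10200: pass condition images as a list plus matching mode indices
# (0 openpose, 1 depth, 2 hed/pidi/scribble/ted, 3 canny/lineart/anime_lineart/mlsd,
#  4 normal, 5 segment; the ProMax checkpoints add 6 tile, 7 repaint).
image = pipe(prompt, control_image=[canny_img], control_mode=[3]).images[0]
```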
src/diffusers/models/controlnets/__init__.py

@@ -15,7 +15,7 @@ if is_torch_available():
         SparseControlNetModel,
         SparseControlNetOutput,
     )
-    from .controlnet_union import ControlNetUnionInput, ControlNetUnionInputProMax, ControlNetUnionModel
+    from .controlnet_union import ControlNetUnionModel
     from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
     from .multicontrolnet import MultiControlNetModel
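Since the two input dataclasses are deleted outright, only the model class remains importable from this subpackage. A quick sanity check, assuming the post-commit package layout:

```py
from diffusers.models.controlnets import ControlNetUnionModel  # still exported

# ControlNetUnionInput / ControlNetUnionInputProMax are removed, so the old
# import now fails with ImportError:
try:
    from diffusers.models.controlnets import ControlNetUnionInput
except ImportError:
    print("ControlNetUnionInput was removed in this refactor")
```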
src/diffusers/models/controlnets/controlnet_union.py

@@ -11,14 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...image_processor import PipelineImageInput
 from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import logging
 from ..attention_processor import (
@@ -40,76 +38,6 @@ from ..unets.unet_2d_condition import UNet2DConditionModel
 from .controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
 
 
-@dataclass
-class ControlNetUnionInput:
-    """
-    The image input of [`ControlNetUnionModel`]:
-    - 0: openpose
-    - 1: depth
-    - 2: hed/pidi/scribble/ted
-    - 3: canny/lineart/anime_lineart/mlsd
-    - 4: normal
-    - 5: segment
-    """
-
-    openpose: Optional[PipelineImageInput] = None
-    depth: Optional[PipelineImageInput] = None
-    hed: Optional[PipelineImageInput] = None
-    canny: Optional[PipelineImageInput] = None
-    normal: Optional[PipelineImageInput] = None
-    segment: Optional[PipelineImageInput] = None
-
-    def __len__(self) -> int:
-        return len(vars(self))
-
-    def __iter__(self):
-        return iter(vars(self))
-
-    def __getitem__(self, key):
-        return getattr(self, key)
-
-    def __setitem__(self, key, value):
-        setattr(self, key, value)
-
-
-@dataclass
-class ControlNetUnionInputProMax:
-    """
-    The image input of [`ControlNetUnionModel`]:
-    - 0: openpose
-    - 1: depth
-    - 2: hed/pidi/scribble/ted
-    - 3: canny/lineart/anime_lineart/mlsd
-    - 4: normal
-    - 5: segment
-    - 6: tile
-    - 7: repaint
-    """
-
-    openpose: Optional[PipelineImageInput] = None
-    depth: Optional[PipelineImageInput] = None
-    hed: Optional[PipelineImageInput] = None
-    canny: Optional[PipelineImageInput] = None
-    normal: Optional[PipelineImageInput] = None
-    segment: Optional[PipelineImageInput] = None
-    tile: Optional[PipelineImageInput] = None
-    repaint: Optional[PipelineImageInput] = None
-
-    def __len__(self) -> int:
-        return len(vars(self))
-
-    def __iter__(self):
-        return iter(vars(self))
-
-    def __getitem__(self, key):
-        return getattr(self, key)
-
-    def __setitem__(self, key, value):
-        setattr(self, key, value)
-
-
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -680,8 +608,9 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         sample: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
-        controlnet_cond: Union[ControlNetUnionInput, ControlNetUnionInputProMax],
+        controlnet_cond: List[torch.Tensor],
         control_type: torch.Tensor,
+        control_type_idx: List[int],
         conditioning_scale: float = 1.0,
         class_labels: Optional[torch.Tensor] = None,
         timestep_cond: Optional[torch.Tensor] = None,
@@ -701,11 +630,13 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 The number of timesteps to denoise an input.
             encoder_hidden_states (`torch.Tensor`):
                 The encoder hidden states.
-            controlnet_cond (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`):
+            controlnet_cond (`List[torch.Tensor]`):
                 The conditional input tensors.
             control_type (`torch.Tensor`):
                 A tensor of shape `(batch, num_control_type)` with values `0` or `1` depending on whether the control
                 type is used.
+            control_type_idx (`List[int]`):
+                The indices of `control_type`.
             conditioning_scale (`float`, defaults to `1.0`):
                 The scale factor for ControlNet outputs.
             class_labels (`torch.Tensor`, *optional*, defaults to `None`):
@@ -733,20 +664,6 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
             returned where the first element is the sample tensor.
         """
-        if not isinstance(controlnet_cond, (ControlNetUnionInput, ControlNetUnionInputProMax)):
-            raise ValueError(
-                "Expected type of `controlnet_cond` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
-            )
-        if len(controlnet_cond) != self.config.num_control_type:
-            if isinstance(controlnet_cond, ControlNetUnionInput):
-                raise ValueError(
-                    f"Expected num_control_type {self.config.num_control_type}, got {len(controlnet_cond)}. Try `ControlNetUnionInputProMax`."
-                )
-            elif isinstance(controlnet_cond, ControlNetUnionInputProMax):
-                raise ValueError(
-                    f"Expected num_control_type {self.config.num_control_type}, got {len(controlnet_cond)}. Try `ControlNetUnionInput`."
-                )
         # check channel order
         channel_order = self.config.controlnet_conditioning_channel_order
@@ -830,12 +747,10 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         inputs = []
         condition_list = []
 
-        for idx, image_type in enumerate(controlnet_cond):
-            if controlnet_cond[image_type] is None:
-                continue
-            condition = self.controlnet_cond_embedding(controlnet_cond[image_type])
+        for cond, control_idx in zip(controlnet_cond, control_type_idx):
+            condition = self.controlnet_cond_embedding(cond)
             feat_seq = torch.mean(condition, dim=(2, 3))
-            feat_seq = feat_seq + self.task_embedding[idx]
+            feat_seq = feat_seq + self.task_embedding[control_idx]
             inputs.append(feat_seq.unsqueeze(1))
             condition_list.append(condition)
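The new forward contract is easiest to see in isolation: each entry of `controlnet_cond` pairs with the `control_type_idx` entry that selects its learned task embedding, so only the conditions actually supplied are iterated (the old code looped over every dataclass field and skipped `None`). A standalone sketch of the refactored loop with stand-in tensors (`task_embedding` mimics `self.task_embedding`, and the `controlnet_cond_embedding` step is replaced here by a pre-embedded tensor):

```py
import torch

num_control_type, embed_dim = 8, 256
task_embedding = torch.randn(num_control_type, embed_dim)  # stand-in for self.task_embedding

controlnet_cond = [torch.randn(1, embed_dim, 64, 64)]  # one already-embedded condition map
control_type_idx = [3]  # 3 = canny/lineart/anime_lineart/mlsd

inputs, condition_list = [], []
for cond, control_idx in zip(controlnet_cond, control_type_idx):
    feat_seq = torch.mean(cond, dim=(2, 3))            # pool each condition to a feature vector
    feat_seq = feat_seq + task_embedding[control_idx]  # add the task embedding for its mode
    inputs.append(feat_seq.unsqueeze(1))
    condition_list.append(cond)
```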
src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py

@@ -40,7 +40,6 @@ from ...models.attention_processor import (
     AttnProcessor2_0,
     XFormersAttnProcessor,
 )
-from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -82,7 +81,6 @@ EXAMPLE_DOC_STRING = """
     Examples:
         ```py
        from diffusers import StableDiffusionXLControlNetUnionInpaintPipeline, ControlNetUnionModel, AutoencoderKL
-       from diffusers.models.controlnets import ControlNetUnionInputProMax
        from diffusers.utils import load_image
        import torch
        import numpy as np
@@ -114,11 +112,8 @@ EXAMPLE_DOC_STRING = """
        mask_np = np.array(mask)
        controlnet_img_np[mask_np > 0] = 0
        controlnet_img = Image.fromarray(controlnet_img_np)
-       union_input = ControlNetUnionInputProMax(
-           repaint=controlnet_img,
-       )
        # generate image
-       image = pipe(prompt, image=image, mask_image=mask, control_image_list=union_input).images[0]
+       image = pipe(prompt, image=image, mask_image=mask, control_image=[controlnet_img], control_mode=[7]).images[0]
        image.save("inpaint.png")
        ```
"""
@@ -1130,7 +1125,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         prompt_2: Optional[Union[str, List[str]]] = None,
         image: PipelineImageInput = None,
         mask_image: PipelineImageInput = None,
-        control_image_list: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None,
+        control_image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         padding_mask_crop: Optional[int] = None,
@@ -1158,6 +1153,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         guess_mode: bool = False,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
+        control_mode: Optional[Union[int, List[int]]] = None,
         guidance_rescale: float = 0.0,
         original_size: Tuple[int, int] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
@@ -1345,20 +1341,6 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
 
-        if not isinstance(control_image_list, (ControlNetUnionInput, ControlNetUnionInputProMax)):
-            raise ValueError(
-                "Expected type of `control_image_list` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
-            )
-        if len(control_image_list) != controlnet.config.num_control_type:
-            if isinstance(control_image_list, ControlNetUnionInput):
-                raise ValueError(
-                    f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInputProMax`."
-                )
-            elif isinstance(control_image_list, ControlNetUnionInputProMax):
-                raise ValueError(
-                    f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInput`."
-                )
-
         # align format for control guidance
         if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
             control_guidance_start = len(control_guidance_end) * [control_guidance_start]
@@ -1375,36 +1357,44 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
             control_guidance_end = len(control_guidance_start) * [control_guidance_end]
 
+        if not isinstance(control_image, list):
+            control_image = [control_image]
+
+        if not isinstance(control_mode, list):
+            control_mode = [control_mode]
+
+        if len(control_image) != len(control_mode):
+            raise ValueError("Expected len(control_image) == len(control_type)")
+
+        num_control_type = controlnet.config.num_control_type
+
         # 1. Check inputs
-        control_type = []
-        for image_type in control_image_list:
-            if control_image_list[image_type]:
-                self.check_inputs(
-                    prompt,
-                    prompt_2,
-                    control_image_list[image_type],
-                    mask_image,
-                    strength,
-                    num_inference_steps,
-                    callback_steps,
-                    output_type,
-                    negative_prompt,
-                    negative_prompt_2,
-                    prompt_embeds,
-                    negative_prompt_embeds,
-                    ip_adapter_image,
-                    ip_adapter_image_embeds,
-                    pooled_prompt_embeds,
-                    negative_pooled_prompt_embeds,
-                    controlnet_conditioning_scale,
-                    control_guidance_start,
-                    control_guidance_end,
-                    callback_on_step_end_tensor_inputs,
-                    padding_mask_crop,
-                )
-                control_type.append(1)
-            else:
-                control_type.append(0)
+        control_type = [0 for _ in range(num_control_type)]
+        for _image, control_idx in zip(control_image, control_mode):
+            control_type[control_idx] = 1
+            self.check_inputs(
+                prompt,
+                prompt_2,
+                _image,
+                mask_image,
+                strength,
+                num_inference_steps,
+                callback_steps,
+                output_type,
+                negative_prompt,
+                negative_prompt_2,
+                prompt_embeds,
+                negative_prompt_embeds,
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                pooled_prompt_embeds,
+                negative_pooled_prompt_embeds,
+                controlnet_conditioning_scale,
+                control_guidance_start,
+                control_guidance_end,
+                callback_on_step_end_tensor_inputs,
+                padding_mask_crop,
+            )
 
         control_type = torch.Tensor(control_type)
@@ -1499,23 +1489,21 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
             init_image = init_image.to(dtype=torch.float32)
 
         # 5.2 Prepare control images
-        for image_type in control_image_list:
-            if control_image_list[image_type]:
-                control_image = self.prepare_control_image(
-                    image=control_image_list[image_type],
-                    width=width,
-                    height=height,
-                    batch_size=batch_size * num_images_per_prompt,
-                    num_images_per_prompt=num_images_per_prompt,
-                    device=device,
-                    dtype=controlnet.dtype,
-                    crops_coords=crops_coords,
-                    resize_mode=resize_mode,
-                    do_classifier_free_guidance=self.do_classifier_free_guidance,
-                    guess_mode=guess_mode,
-                )
-                height, width = control_image.shape[-2:]
-                control_image_list[image_type] = control_image
+        for idx, _ in enumerate(control_image):
+            control_image[idx] = self.prepare_control_image(
+                image=control_image[idx],
+                width=width,
+                height=height,
+                batch_size=batch_size * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=device,
+                dtype=controlnet.dtype,
+                crops_coords=crops_coords,
+                resize_mode=resize_mode,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
+                guess_mode=guess_mode,
+            )
+            height, width = control_image[idx].shape[-2:]
 
         # 5.3 Prepare mask
         mask = self.mask_processor.preprocess(
@@ -1589,6 +1577,9 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         original_size = original_size or (height, width)
         target_size = target_size or (height, width)
 
+        for _image in control_image:
+            if isinstance(_image, torch.Tensor):
+                original_size = original_size or _image.shape[-2:]
+
         # 10. Prepare added time ids & embeddings
         add_text_embeds = pooled_prompt_embeds
@@ -1693,8 +1684,9 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
                         control_model_input,
                         t,
                         encoder_hidden_states=controlnet_prompt_embeds,
-                        controlnet_cond=control_image_list,
+                        controlnet_cond=control_image,
                         control_type=control_type,
+                        control_type_idx=control_mode,
                         conditioning_scale=cond_scale,
                         guess_mode=guess_mode,
                         added_cond_kwargs=controlnet_added_cond_kwargs,
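All three pipelines now build the `control_type` flag tensor the same way: a one-hot-style vector over `num_control_type` slots, set from `control_mode`. A small sketch of that construction for the repaint example above (`num_control_type=8` is the ProMax configuration):

```py
import torch

num_control_type = 8  # from controlnet.config.num_control_type (ProMax)
control_mode = [7]    # 7 = repaint, as in the updated docstring example

control_type = [0 for _ in range(num_control_type)]
for control_idx in control_mode:
    control_type[control_idx] = 1

control_type = torch.Tensor(control_type)
print(control_type)  # tensor([0., 0., 0., 0., 0., 0., 0., 1.])
```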
src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py

@@ -43,7 +43,6 @@ from ...models.attention_processor import (
     AttnProcessor2_0,
     XFormersAttnProcessor,
 )
-from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -70,7 +69,6 @@ EXAMPLE_DOC_STRING = """
        >>> # !pip install controlnet_aux
        >>> from controlnet_aux import LineartAnimeDetector
        >>> from diffusers import StableDiffusionXLControlNetUnionPipeline, ControlNetUnionModel, AutoencoderKL
-       >>> from diffusers.models.controlnets import ControlNetUnionInput
        >>> from diffusers.utils import load_image
        >>> import torch
@@ -89,17 +87,14 @@ EXAMPLE_DOC_STRING = """
        ...     controlnet=controlnet,
        ...     vae=vae,
        ...     torch_dtype=torch.float16,
+       ...     variant="fp16",
        ... )
        >>> pipe.enable_model_cpu_offload()
        >>> # prepare image
        >>> processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
        >>> controlnet_img = processor(image, output_type="pil")
-       >>> # set ControlNetUnion input
-       >>> union_input = ControlNetUnionInput(
-       ...     canny=controlnet_img,
-       ... )
        >>> # generate image
-       >>> image = pipe(prompt, image=union_input).images[0]
+       >>> image = pipe(prompt, control_image=[controlnet_img], control_mode=[3], height=1024, width=1024).images[0]
        ```
"""
@@ -791,26 +786,6 @@ class StableDiffusionXLControlNetUnionPipeline(
                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                )
 
-    def check_input(
-        self,
-        image: Union[ControlNetUnionInput, ControlNetUnionInputProMax],
-    ):
-        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
-        if not isinstance(image, (ControlNetUnionInput, ControlNetUnionInputProMax)):
-            raise ValueError(
-                "Expected type of `image` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
-            )
-        if len(image) != controlnet.config.num_control_type:
-            if isinstance(image, ControlNetUnionInput):
-                raise ValueError(
-                    f"Expected num_control_type {controlnet.config.num_control_type}, got {len(image)}. Try `ControlNetUnionInputProMax`."
-                )
-            elif isinstance(image, ControlNetUnionInputProMax):
-                raise ValueError(
-                    f"Expected num_control_type {controlnet.config.num_control_type}, got {len(image)}. Try `ControlNetUnionInput`."
-                )
-
     # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
     def prepare_image(
         self,
@@ -970,7 +945,7 @@ class StableDiffusionXLControlNetUnionPipeline(
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
-        image: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None,
+        control_image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,
@@ -997,6 +972,7 @@ class StableDiffusionXLControlNetUnionPipeline(
         guess_mode: bool = False,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
+        control_mode: Optional[Union[int, List[int]]] = None,
         original_size: Tuple[int, int] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
         target_size: Tuple[int, int] = None,
@@ -1018,10 +994,7 @@ class StableDiffusionXLControlNetUnionPipeline(
             prompt_2 (`str` or `List[str]`, *optional*):
                 The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                 used in both text-encoders.
-            image (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`):
-                In turn this supports (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`,
-                `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[List[torch.FloatTensor]]`,
-                `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+            control_image (`PipelineImageInput`):
                 The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
                 specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
                 as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
@@ -1168,38 +1141,45 @@ class StableDiffusionXLControlNetUnionPipeline(
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
 
-        self.check_input(image)
-
         # align format for control guidance
         if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
             control_guidance_start = len(control_guidance_end) * [control_guidance_start]
         elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
             control_guidance_end = len(control_guidance_start) * [control_guidance_end]
 
+        if not isinstance(control_image, list):
+            control_image = [control_image]
+
+        if not isinstance(control_mode, list):
+            control_mode = [control_mode]
+
+        if len(control_image) != len(control_mode):
+            raise ValueError("Expected len(control_image) == len(control_type)")
+
+        num_control_type = controlnet.config.num_control_type
+
         # 1. Check inputs. Raise error if not correct
-        control_type = []
-        for image_type in image:
-            if image[image_type]:
-                self.check_inputs(
-                    prompt,
-                    prompt_2,
-                    image[image_type],
-                    negative_prompt,
-                    negative_prompt_2,
-                    prompt_embeds,
-                    negative_prompt_embeds,
-                    pooled_prompt_embeds,
-                    ip_adapter_image,
-                    ip_adapter_image_embeds,
-                    negative_pooled_prompt_embeds,
-                    controlnet_conditioning_scale,
-                    control_guidance_start,
-                    control_guidance_end,
-                    callback_on_step_end_tensor_inputs,
-                )
-                control_type.append(1)
-            else:
-                control_type.append(0)
+        control_type = [0 for _ in range(num_control_type)]
+        for _image, control_idx in zip(control_image, control_mode):
+            control_type[control_idx] = 1
+            self.check_inputs(
+                prompt,
+                prompt_2,
+                _image,
+                negative_prompt,
+                negative_prompt_2,
+                prompt_embeds,
+                negative_prompt_embeds,
+                pooled_prompt_embeds,
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                negative_pooled_prompt_embeds,
+                controlnet_conditioning_scale,
+                control_guidance_start,
+                control_guidance_end,
+                callback_on_step_end_tensor_inputs,
+            )
 
         control_type = torch.Tensor(control_type)
@@ -1258,20 +1238,19 @@ class StableDiffusionXLControlNetUnionPipeline(
         )
 
         # 4. Prepare image
-        for image_type in image:
-            if image[image_type]:
-                image[image_type] = self.prepare_image(
-                    image=image[image_type],
-                    width=width,
-                    height=height,
-                    batch_size=batch_size * num_images_per_prompt,
-                    num_images_per_prompt=num_images_per_prompt,
-                    device=device,
-                    dtype=controlnet.dtype,
-                    do_classifier_free_guidance=self.do_classifier_free_guidance,
-                    guess_mode=guess_mode,
-                )
-                height, width = image[image_type].shape[-2:]
+        for idx, _ in enumerate(control_image):
+            control_image[idx] = self.prepare_image(
+                image=control_image[idx],
+                width=width,
+                height=height,
+                batch_size=batch_size * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=device,
+                dtype=controlnet.dtype,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
+                guess_mode=guess_mode,
+            )
+            height, width = control_image[idx].shape[-2:]
 
         # 5. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(
@@ -1312,11 +1291,11 @@ class StableDiffusionXLControlNetUnionPipeline(
         )
 
         # 7.2 Prepare added time ids & embeddings
-        for image_type in image:
-            if isinstance(image[image_type], torch.Tensor):
-                original_size = original_size or image[image_type].shape[-2:]
-
+        original_size = original_size or (height, width)
         target_size = target_size or (height, width)
 
+        for _image in control_image:
+            if isinstance(_image, torch.Tensor):
+                original_size = original_size or _image.shape[-2:]
+
         add_text_embeds = pooled_prompt_embeds
         if self.text_encoder_2 is None:
             text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
@@ -1424,8 +1403,9 @@ class StableDiffusionXLControlNetUnionPipeline(
                        control_model_input,
                        t,
                        encoder_hidden_states=controlnet_prompt_embeds,
-                       controlnet_cond=image,
+                       controlnet_cond=control_image,
                        control_type=control_type,
+                       control_type_idx=control_mode,
                        conditioning_scale=cond_scale,
                        guess_mode=guess_mode,
                        added_cond_kwargs=controlnet_added_cond_kwargs,
@@ -1478,7 +1458,6 @@ class StableDiffusionXLControlNetUnionPipeline(
                 add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
                 negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
-                image = callback_outputs.pop("image", image)
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
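Because conditions are now plain parallel lists, combining several control types in one call no longer needs a dataclass. A hedged multi-condition sketch (not from this diff; `pipe` is a loaded StableDiffusionXLControlNetUnionPipeline and `pose_img`/`depth_img` are assumed preprocessed condition images):

```py
# Sketch only: one condition image per entry, with its mode index in the
# parallel control_mode list (0 = openpose, 1 = depth).
image = pipe(
    prompt,
    control_image=[pose_img, depth_img],
    control_mode=[0, 1],
    height=1024,
    width=1024,
).images[0]
```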
src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py

@@ -43,7 +43,6 @@ from ...models.attention_processor import (
     AttnProcessor2_0,
     XFormersAttnProcessor,
 )
-from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -74,7 +73,6 @@ EXAMPLE_DOC_STRING = """
            ControlNetUnionModel,
            AutoencoderKL,
        )
-       from diffusers.models.controlnets import ControlNetUnionInputProMax
        from diffusers.utils import load_image
        import torch
        from PIL import Image
@@ -95,6 +93,7 @@ EXAMPLE_DOC_STRING = """
            controlnet=controlnet,
            vae=vae,
            torch_dtype=torch.float16,
+           variant="fp16",
        ).to("cuda")
        # `enable_model_cpu_offload` is not recommended due to multiple generations
        height = image.height
@@ -132,14 +131,12 @@ EXAMPLE_DOC_STRING = """
        # set ControlNetUnion input
        result_images = []
        for sub_img, crops_coords in zip(images, crops_coords_list):
-           union_input = ControlNetUnionInputProMax(
-               tile=sub_img,
-           )
            new_width, new_height = W, H
            out = pipe(
                prompt=[prompt] * 1,
                image=sub_img,
-               control_image_list=union_input,
+               control_image=[sub_img],
+               control_mode=[6],
                width=new_width,
                height=new_height,
                num_inference_steps=30,
@@ -1065,7 +1062,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
         image: PipelineImageInput = None,
-        control_image_list: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None,
+        control_image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         strength: float = 0.8,
@@ -1090,6 +1087,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
         guess_mode: bool = False,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
+        control_mode: Optional[Union[int, List[int]]] = None,
         original_size: Tuple[int, int] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
         target_size: Tuple[int, int] = None,
@@ -1119,10 +1117,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
                 `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
                 The initial image will be used as the starting point for the image generation process. Can also accept
                 image latents as `image`, if passing latents directly, it will not be encoded again.
-            control_image_list (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`):
-                In turn this supports (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`,
-                `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[List[torch.FloatTensor]]`,
-                `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`)::
+            control_image (`PipelineImageInput`):
                 The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
                 the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also
                 be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
@@ -1291,53 +1286,47 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
 
-        if not isinstance(control_image_list, (ControlNetUnionInput, ControlNetUnionInputProMax)):
-            raise ValueError(
-                "Expected type of `control_image_list` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
-            )
-        if len(control_image_list) != controlnet.config.num_control_type:
-            if isinstance(control_image_list, ControlNetUnionInput):
-                raise ValueError(
-                    f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInputProMax`."
-                )
-            elif isinstance(control_image_list, ControlNetUnionInputProMax):
-                raise ValueError(
-                    f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInput`."
-                )
-
         # align format for control guidance
         if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
             control_guidance_start = len(control_guidance_end) * [control_guidance_start]
         elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
             control_guidance_end = len(control_guidance_start) * [control_guidance_end]
 
-        # 1. Check inputs. Raise error if not correct
-        control_type = []
-        for image_type in control_image_list:
-            if control_image_list[image_type]:
-                self.check_inputs(
-                    prompt,
-                    prompt_2,
-                    control_image_list[image_type],
-                    strength,
-                    num_inference_steps,
-                    callback_steps,
-                    negative_prompt,
-                    negative_prompt_2,
-                    prompt_embeds,
-                    negative_prompt_embeds,
-                    pooled_prompt_embeds,
-                    negative_pooled_prompt_embeds,
-                    ip_adapter_image,
-                    ip_adapter_image_embeds,
-                    controlnet_conditioning_scale,
-                    control_guidance_start,
-                    control_guidance_end,
-                    callback_on_step_end_tensor_inputs,
-                )
-                control_type.append(1)
-            else:
-                control_type.append(0)
+        if not isinstance(control_image, list):
+            control_image = [control_image]
+
+        if not isinstance(control_mode, list):
+            control_mode = [control_mode]
+
+        if len(control_image) != len(control_mode):
+            raise ValueError("Expected len(control_image) == len(control_type)")
+
+        num_control_type = controlnet.config.num_control_type
+
+        # 1. Check inputs
+        control_type = [0 for _ in range(num_control_type)]
+        for _image, control_idx in zip(control_image, control_mode):
+            control_type[control_idx] = 1
+            self.check_inputs(
+                prompt,
+                prompt_2,
+                _image,
+                strength,
+                num_inference_steps,
+                callback_steps,
+                negative_prompt,
+                negative_prompt_2,
+                prompt_embeds,
+                negative_prompt_embeds,
+                pooled_prompt_embeds,
+                negative_pooled_prompt_embeds,
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                controlnet_conditioning_scale,
+                control_guidance_start,
+                control_guidance_end,
+                callback_on_step_end_tensor_inputs,
+            )
 
         control_type = torch.Tensor(control_type)
@@ -1397,21 +1386,19 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
         # 4. Prepare image and controlnet_conditioning_image
         image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
 
-        for image_type in control_image_list:
-            if control_image_list[image_type]:
-                control_image = self.prepare_control_image(
-                    image=control_image_list[image_type],
-                    width=width,
-                    height=height,
-                    batch_size=batch_size * num_images_per_prompt,
-                    num_images_per_prompt=num_images_per_prompt,
-                    device=device,
-                    dtype=controlnet.dtype,
-                    do_classifier_free_guidance=self.do_classifier_free_guidance,
-                    guess_mode=guess_mode,
-                )
-                height, width = control_image.shape[-2:]
-                control_image_list[image_type] = control_image
+        for idx, _ in enumerate(control_image):
+            control_image[idx] = self.prepare_control_image(
+                image=control_image[idx],
+                width=width,
+                height=height,
+                batch_size=batch_size * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=device,
+                dtype=controlnet.dtype,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
+                guess_mode=guess_mode,
+            )
+            height, width = control_image[idx].shape[-2:]
 
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -1444,10 +1431,11 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
         )
 
         # 7.2 Prepare added time ids & embeddings
-        for image_type in control_image_list:
-            if isinstance(control_image_list[image_type], torch.Tensor):
-                original_size = original_size or control_image_list[image_type].shape[-2:]
-
+        original_size = original_size or (height, width)
         target_size = target_size or (height, width)
 
+        for _image in control_image:
+            if isinstance(_image, torch.Tensor):
+                original_size = original_size or _image.shape[-2:]
+
         if negative_original_size is None:
             negative_original_size = original_size
@@ -1531,8 +1519,9 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
                        control_model_input,
                        t,
                        encoder_hidden_states=controlnet_prompt_embeds,
-                       controlnet_cond=control_image_list,
+                       controlnet_cond=control_image,
                        control_type=control_type,
+                       control_type_idx=control_mode,
                        conditioning_scale=cond_scale,
                        guess_mode=guess_mode,
                        added_cond_kwargs=controlnet_added_cond_kwargs,
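Note that each pipeline normalizes scalar arguments to lists before validating `len(control_image) == len(control_mode)`, so scalar and single-element-list forms should behave identically. A sketch assuming a loaded img2img pipeline and one tile image:

```py
# Both calls should be equivalent after the internal normalization
# (`control_image` and `control_mode` are wrapped into lists when scalars):
out_a = pipe(prompt=prompt, image=tile_img, control_image=tile_img, control_mode=6)
out_b = pipe(prompt=prompt, image=tile_img, control_image=[tile_img], control_mode=[6])
```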