renzhc / diffusers_dcu · Commit e8b65bff

Unverified commit e8b65bff, authored Dec 12, 2024 by hlky, committed by GitHub on Dec 12, 2024

refactor StableDiffusionXLControlNetUnion (#10200)

parent f2d348d9

Showing 5 changed files with 185 additions and 310 deletions:
src/diffusers/models/controlnets/__init__.py (+1, -1)
src/diffusers/models/controlnets/controlnet_union.py (+8, -93)
src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py (+60, -68)
src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py (+56, -77)
src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py (+60, -71)
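Taken together, these changes drop the `ControlNetUnionInput` / `ControlNetUnionInputProMax` dataclasses and let the three SDXL ControlNetUnion pipelines accept a plain list of conditioning images (`control_image`) paired with a list of control-type indices (`control_mode`). A minimal before/after sketch of the new calling convention, using the variable names from the inpaint docstring example changed below:

```py
# Old API (removed in this commit):
#   union_input = ControlNetUnionInputProMax(repaint=controlnet_img)
#   image = pipe(prompt, image=image, mask_image=mask, control_image_list=union_input).images[0]

# New API: conditioning images go in a list, with matching control-type indices (7 = repaint):
image = pipe(
    prompt,
    image=image,
    mask_image=mask,
    control_image=[controlnet_img],
    control_mode=[7],
).images[0]
```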
src/diffusers/models/controlnets/__init__.py

@@ -15,7 +15,7 @@ if is_torch_available():
         SparseControlNetModel,
         SparseControlNetOutput,
     )
-    from .controlnet_union import ControlNetUnionInput, ControlNetUnionInputProMax, ControlNetUnionModel
+    from .controlnet_union import ControlNetUnionModel
     from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
     from .multicontrolnet import MultiControlNetModel
src/diffusers/models/controlnets/controlnet_union.py

@@ -11,14 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 from torch import nn

 from ...configuration_utils import ConfigMixin, register_to_config
-from ...image_processor import PipelineImageInput
 from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import logging
 from ..attention_processor import (
@@ -40,76 +38,6 @@ from ..unets.unet_2d_condition import UNet2DConditionModel
 from .controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module


-@dataclass
-class ControlNetUnionInput:
-    """
-    The image input of [`ControlNetUnionModel`]:
-    - 0: openpose
-    - 1: depth
-    - 2: hed/pidi/scribble/ted
-    - 3: canny/lineart/anime_lineart/mlsd
-    - 4: normal
-    - 5: segment
-    """
-
-    openpose: Optional[PipelineImageInput] = None
-    depth: Optional[PipelineImageInput] = None
-    hed: Optional[PipelineImageInput] = None
-    canny: Optional[PipelineImageInput] = None
-    normal: Optional[PipelineImageInput] = None
-    segment: Optional[PipelineImageInput] = None
-
-    def __len__(self) -> int:
-        return len(vars(self))
-
-    def __iter__(self):
-        return iter(vars(self))
-
-    def __getitem__(self, key):
-        return getattr(self, key)
-
-    def __setitem__(self, key, value):
-        setattr(self, key, value)
-
-
-@dataclass
-class ControlNetUnionInputProMax:
-    """
-    The image input of [`ControlNetUnionModel`]:
-    - 0: openpose
-    - 1: depth
-    - 2: hed/pidi/scribble/ted
-    - 3: canny/lineart/anime_lineart/mlsd
-    - 4: normal
-    - 5: segment
-    - 6: tile
-    - 7: repaint
-    """
-
-    openpose: Optional[PipelineImageInput] = None
-    depth: Optional[PipelineImageInput] = None
-    hed: Optional[PipelineImageInput] = None
-    canny: Optional[PipelineImageInput] = None
-    normal: Optional[PipelineImageInput] = None
-    segment: Optional[PipelineImageInput] = None
-    tile: Optional[PipelineImageInput] = None
-    repaint: Optional[PipelineImageInput] = None
-
-    def __len__(self) -> int:
-        return len(vars(self))
-
-    def __iter__(self):
-        return iter(vars(self))
-
-    def __getitem__(self, key):
-        return getattr(self, key)
-
-    def __setitem__(self, key, value):
-        setattr(self, key, value)
-
-
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
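Note that the field order of the removed dataclasses is exactly what the new `control_mode` / `control_type_idx` indices refer to. For reference, the index-to-condition mapping recovered from the deleted docstrings, written as a plain dict (illustration only, not part of this commit):

```py
# Control-type indices, per the docstrings of the removed dataclasses.
CONTROL_MODE_INDEX = {
    "openpose": 0,
    "depth": 1,
    "hed/pidi/scribble/ted": 2,
    "canny/lineart/anime_lineart/mlsd": 3,
    "normal": 4,
    "segment": 5,
    "tile": 6,     # ProMax checkpoints only
    "repaint": 7,  # ProMax checkpoints only
}
```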
@@ -680,8 +608,9 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         sample: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
-        controlnet_cond: Union[ControlNetUnionInput, ControlNetUnionInputProMax],
+        controlnet_cond: List[torch.Tensor],
         control_type: torch.Tensor,
+        control_type_idx: List[int],
         conditioning_scale: float = 1.0,
         class_labels: Optional[torch.Tensor] = None,
         timestep_cond: Optional[torch.Tensor] = None,
@@ -701,11 +630,13 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 The number of timesteps to denoise an input.
             encoder_hidden_states (`torch.Tensor`):
                 The encoder hidden states.
-            controlnet_cond (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`):
+            controlnet_cond (`List[torch.Tensor]`):
                 The conditional input tensors.
             control_type (`torch.Tensor`):
                 A tensor of shape `(batch, num_control_type)` with values `0` or `1` depending on whether the control
                 type is used.
+            control_type_idx (`List[int]`):
+                The indices of `control_type`.
             conditioning_scale (`float`, defaults to `1.0`):
                 The scale factor for ControlNet outputs.
             class_labels (`torch.Tensor`, *optional*, defaults to `None`):
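With the new signature, calling the model directly looks roughly like the sketch below. This is an illustration rather than part of the commit: the shapes and SDXL micro-conditioning values are placeholders, and `controlnet` stands for a loaded `ControlNetUnionModel`:

```py
import torch

sample = torch.randn(1, 4, 128, 128)              # noisy latents (illustrative shape)
encoder_hidden_states = torch.randn(1, 77, 2048)  # SDXL text embeddings (illustrative)
cond = torch.randn(1, 3, 1024, 1024)              # one prepared conditioning image
control_type = torch.zeros(1, 8)
control_type[:, 3] = 1                            # index 3: canny/lineart family

out = controlnet(
    sample,
    10,                                           # timestep
    encoder_hidden_states,
    controlnet_cond=[cond],                       # now a plain list of tensors
    control_type=control_type,
    control_type_idx=[3],                         # one index per entry in controlnet_cond
    conditioning_scale=1.0,
    added_cond_kwargs={                           # SDXL micro-conditioning, shapes illustrative
        "text_embeds": torch.randn(1, 1280),
        "time_ids": torch.randn(1, 6),
    },
)
```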
@@ -733,20 +664,6 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
             returned where the first element is the sample tensor.
         """
-        if not isinstance(controlnet_cond, (ControlNetUnionInput, ControlNetUnionInputProMax)):
-            raise ValueError(
-                "Expected type of `controlnet_cond` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
-            )
-        if len(controlnet_cond) != self.config.num_control_type:
-            if isinstance(controlnet_cond, ControlNetUnionInput):
-                raise ValueError(
-                    f"Expected num_control_type {self.config.num_control_type}, got {len(controlnet_cond)}. Try `ControlNetUnionInputProMax`."
-                )
-            elif isinstance(controlnet_cond, ControlNetUnionInputProMax):
-                raise ValueError(
-                    f"Expected num_control_type {self.config.num_control_type}, got {len(controlnet_cond)}. Try `ControlNetUnionInput`."
-                )
         # check channel order
         channel_order = self.config.controlnet_conditioning_channel_order
@@ -830,12 +747,10 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         inputs = []
         condition_list = []

-        for idx, image_type in enumerate(controlnet_cond):
-            if controlnet_cond[image_type] is None:
-                continue
-            condition = self.controlnet_cond_embedding(controlnet_cond[image_type])
+        for cond, control_idx in zip(controlnet_cond, control_type_idx):
+            condition = self.controlnet_cond_embedding(cond)
             feat_seq = torch.mean(condition, dim=(2, 3))
-            feat_seq = feat_seq + self.task_embedding[idx]
+            feat_seq = feat_seq + self.task_embedding[control_idx]
             inputs.append(feat_seq.unsqueeze(1))
             condition_list.append(condition)
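The behavioral core of the refactor is visible here: the task embedding is selected by the caller-supplied index rather than by position in a fixed-size input, so sparse inputs no longer need `None` placeholders. A standalone sketch of just that pairing (hypothetical tensors; the `controlnet_cond_embedding` step is omitted):

```py
import torch

task_embedding = torch.zeros(8, 320)            # stand-in for the model's task embeddings
conditions = [torch.randn(1, 320, 64, 64)] * 2  # two already-embedded conditions
control_type_idx = [3, 1]                       # canny -> 3, depth -> 1

inputs = []
for cond, control_idx in zip(conditions, control_type_idx):
    feat_seq = torch.mean(cond, dim=(2, 3))             # (batch, channels)
    feat_seq = feat_seq + task_embedding[control_idx]   # embedding chosen by index, not position
    inputs.append(feat_seq.unsqueeze(1))
```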
src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py

@@ -40,7 +40,6 @@ from ...models.attention_processor import (
     AttnProcessor2_0,
     XFormersAttnProcessor,
 )
-from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -82,7 +81,6 @@ EXAMPLE_DOC_STRING = """
     Examples:
         ```py
         from diffusers import StableDiffusionXLControlNetUnionInpaintPipeline, ControlNetUnionModel, AutoencoderKL
-        from diffusers.models.controlnets import ControlNetUnionInputProMax
         from diffusers.utils import load_image
         import torch
         import numpy as np
@@ -114,11 +112,8 @@ EXAMPLE_DOC_STRING = """
         mask_np = np.array(mask)
         controlnet_img_np[mask_np > 0] = 0
         controlnet_img = Image.fromarray(controlnet_img_np)
-        union_input = ControlNetUnionInputProMax(
-            repaint=controlnet_img,
-        )
         # generate image
-        image = pipe(prompt, image=image, mask_image=mask, control_image_list=union_input).images[0]
+        image = pipe(prompt, image=image, mask_image=mask, control_image=[controlnet_img], control_mode=[7]).images[0]
         image.save("inpaint.png")
         ```
"""
@@ -1130,7 +1125,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         prompt_2: Optional[Union[str, List[str]]] = None,
         image: PipelineImageInput = None,
         mask_image: PipelineImageInput = None,
-        control_image_list: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None,
+        control_image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         padding_mask_crop: Optional[int] = None,
@@ -1158,6 +1153,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         guess_mode: bool = False,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
+        control_mode: Optional[Union[int, List[int]]] = None,
         guidance_rescale: float = 0.0,
         original_size: Tuple[int, int] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
@@ -1345,20 +1341,6 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

-        if not isinstance(control_image_list, (ControlNetUnionInput, ControlNetUnionInputProMax)):
-            raise ValueError(
-                "Expected type of `control_image_list` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
-            )
-        if len(control_image_list) != controlnet.config.num_control_type:
-            if isinstance(control_image_list, ControlNetUnionInput):
-                raise ValueError(
-                    f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInputProMax`."
-                )
-            elif isinstance(control_image_list, ControlNetUnionInputProMax):
-                raise ValueError(
-                    f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInput`."
-                )
         # align format for control guidance
         if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
             control_guidance_start = len(control_guidance_end) * [control_guidance_start]
@@ -1375,36 +1357,44 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
elif
not
isinstance
(
control_guidance_end
,
list
)
and
isinstance
(
control_guidance_start
,
list
):
control_guidance_end
=
len
(
control_guidance_start
)
*
[
control_guidance_end
]
if
not
isinstance
(
control_image
,
list
):
control_image
=
[
control_image
]
if
not
isinstance
(
control_mode
,
list
):
control_mode
=
[
control_mode
]
if
len
(
control_image
)
!=
len
(
control_mode
):
raise
ValueError
(
"Expected len(control_image) == len(control_type)"
)
num_control_type
=
controlnet
.
config
.
num_control_type
# 1. Check inputs
control_type
=
[]
for
image_type
in
control_image_list
:
if
control_image_list
[
image_type
]:
self
.
check_inputs
(
prompt
,
prompt_2
,
control_image_list
[
image_type
],
mask_image
,
strength
,
num_inference_steps
,
callback_steps
,
output_type
,
negative_prompt
,
negative_prompt_2
,
prompt_embeds
,
negative_prompt_embeds
,
ip_adapter_image
,
ip_adapter_image_embeds
,
pooled_prompt_embeds
,
negative_pooled_prompt_embeds
,
controlnet_conditioning_scale
,
control_guidance_start
,
control_guidance_end
,
callback_on_step_end_tensor_inputs
,
padding_mask_crop
,
)
control_type
.
append
(
1
)
else
:
control_type
.
append
(
0
)
control_type
=
[
0
for
_
in
range
(
num_control_type
)]
for
_image
,
control_idx
in
zip
(
control_image
,
control_mode
):
control_type
[
control_idx
]
=
1
self
.
check_inputs
(
prompt
,
prompt_2
,
_image
,
mask_image
,
strength
,
num_inference_steps
,
callback_steps
,
output_type
,
negative_prompt
,
negative_prompt_2
,
prompt_embeds
,
negative_prompt_embeds
,
ip_adapter_image
,
ip_adapter_image_embeds
,
pooled_prompt_embeds
,
negative_pooled_prompt_embeds
,
controlnet_conditioning_scale
,
control_guidance_start
,
control_guidance_end
,
callback_on_step_end_tensor_inputs
,
padding_mask_crop
,
)
control_type
=
torch
.
Tensor
(
control_type
)
...
...
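All three pipelines now derive the `control_type` flag vector from `control_mode` in the same way. In isolation (a sketch; `num_control_type` comes from the ControlNet config, 8 for ProMax checkpoints):

```py
num_control_type = 8
control_mode = [6, 7]  # e.g. tile + repaint
control_type = [0 for _ in range(num_control_type)]
for control_idx in control_mode:
    control_type[control_idx] = 1
# control_type == [0, 0, 0, 0, 0, 0, 1, 1]
```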
@@ -1499,23 +1489,21 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
init_image
=
init_image
.
to
(
dtype
=
torch
.
float32
)
# 5.2 Prepare control images
for
image_type
in
control_image_list
:
if
control_image_list
[
image_type
]:
control_image
=
self
.
prepare_control_image
(
image
=
control_image_list
[
image_type
],
width
=
width
,
height
=
height
,
batch_size
=
batch_size
*
num_images_per_prompt
,
num_images_per_prompt
=
num_images_per_prompt
,
device
=
device
,
dtype
=
controlnet
.
dtype
,
crops_coords
=
crops_coords
,
resize_mode
=
resize_mode
,
do_classifier_free_guidance
=
self
.
do_classifier_free_guidance
,
guess_mode
=
guess_mode
,
)
height
,
width
=
control_image
.
shape
[
-
2
:]
control_image_list
[
image_type
]
=
control_image
for
idx
,
_
in
enumerate
(
control_image
):
control_image
[
idx
]
=
self
.
prepare_control_image
(
image
=
control_image
[
idx
],
width
=
width
,
height
=
height
,
batch_size
=
batch_size
*
num_images_per_prompt
,
num_images_per_prompt
=
num_images_per_prompt
,
device
=
device
,
dtype
=
controlnet
.
dtype
,
crops_coords
=
crops_coords
,
resize_mode
=
resize_mode
,
do_classifier_free_guidance
=
self
.
do_classifier_free_guidance
,
guess_mode
=
guess_mode
,
)
height
,
width
=
control_image
[
idx
].
shape
[
-
2
:]
# 5.3 Prepare mask
mask
=
self
.
mask_processor
.
preprocess
(
...
...
@@ -1589,6 +1577,9 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
original_size
=
original_size
or
(
height
,
width
)
target_size
=
target_size
or
(
height
,
width
)
for
_image
in
control_image
:
if
isinstance
(
_image
,
torch
.
Tensor
):
original_size
=
original_size
or
_image
.
shape
[
-
2
:]
# 10. Prepare added time ids & embeddings
add_text_embeds
=
pooled_prompt_embeds
...
...
@@ -1693,8 +1684,9 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
control_model_input
,
t
,
encoder_hidden_states
=
controlnet_prompt_embeds
,
controlnet_cond
=
control_image
_list
,
controlnet_cond
=
control_image
,
control_type
=
control_type
,
control_type_idx
=
control_mode
,
conditioning_scale
=
cond_scale
,
guess_mode
=
guess_mode
,
added_cond_kwargs
=
controlnet_added_cond_kwargs
,
...
...
src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py

@@ -43,7 +43,6 @@ from ...models.attention_processor import (
     AttnProcessor2_0,
     XFormersAttnProcessor,
 )
-from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -70,7 +69,6 @@ EXAMPLE_DOC_STRING = """
         >>> # !pip install controlnet_aux
         >>> from controlnet_aux import LineartAnimeDetector
         >>> from diffusers import StableDiffusionXLControlNetUnionPipeline, ControlNetUnionModel, AutoencoderKL
-        >>> from diffusers.models.controlnets import ControlNetUnionInput
         >>> from diffusers.utils import load_image
         >>> import torch
@@ -89,17 +87,14 @@ EXAMPLE_DOC_STRING = """
         ...     controlnet=controlnet,
         ...     vae=vae,
         ...     torch_dtype=torch.float16,
+        ...     variant="fp16",
         ... )
         >>> pipe.enable_model_cpu_offload()
         >>> # prepare image
         >>> processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
         >>> controlnet_img = processor(image, output_type="pil")
-        >>> # set ControlNetUnion input
-        >>> union_input = ControlNetUnionInput(
-        ...     canny=controlnet_img,
-        ... )
         >>> # generate image
-        >>> image = pipe(prompt, image=union_input).images[0]
+        >>> image = pipe(prompt, control_image=[controlnet_img], control_mode=[3], height=1024, width=1024).images[0]
         ```
"""
@@ -791,26 +786,6 @@ class StableDiffusionXLControlNetUnionPipeline(
                     f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                 )

-    def check_input(
-        self,
-        image: Union[ControlNetUnionInput, ControlNetUnionInputProMax],
-    ):
-        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
-        if not isinstance(image, (ControlNetUnionInput, ControlNetUnionInputProMax)):
-            raise ValueError(
-                "Expected type of `image` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
-            )
-        if len(image) != controlnet.config.num_control_type:
-            if isinstance(image, ControlNetUnionInput):
-                raise ValueError(
-                    f"Expected num_control_type {controlnet.config.num_control_type}, got {len(image)}. Try `ControlNetUnionInputProMax`."
-                )
-            elif isinstance(image, ControlNetUnionInputProMax):
-                raise ValueError(
-                    f"Expected num_control_type {controlnet.config.num_control_type}, got {len(image)}. Try `ControlNetUnionInput`."
-                )
-
     # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
     def prepare_image(
         self,
@@ -970,7 +945,7 @@ class StableDiffusionXLControlNetUnionPipeline(
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
-        image: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None,
+        control_image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,
@@ -997,6 +972,7 @@ class StableDiffusionXLControlNetUnionPipeline(
         guess_mode: bool = False,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
+        control_mode: Optional[Union[int, List[int]]] = None,
         original_size: Tuple[int, int] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
         target_size: Tuple[int, int] = None,
@@ -1018,10 +994,7 @@ class StableDiffusionXLControlNetUnionPipeline(
             prompt_2 (`str` or `List[str]`, *optional*):
                 The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                 used in both text-encoders.
-            image (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`):
-                In turn this supports (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`,
-                `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[List[torch.FloatTensor]]`,
-                `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+            control_image (`PipelineImageInput`):
                 The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
                 specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
                 as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
@@ -1168,38 +1141,45 @@ class StableDiffusionXLControlNetUnionPipeline(
controlnet
=
self
.
controlnet
.
_orig_mod
if
is_compiled_module
(
self
.
controlnet
)
else
self
.
controlnet
self
.
check_input
(
image
)
# align format for control guidance
if
not
isinstance
(
control_guidance_start
,
list
)
and
isinstance
(
control_guidance_end
,
list
):
control_guidance_start
=
len
(
control_guidance_end
)
*
[
control_guidance_start
]
elif
not
isinstance
(
control_guidance_end
,
list
)
and
isinstance
(
control_guidance_start
,
list
):
control_guidance_end
=
len
(
control_guidance_start
)
*
[
control_guidance_end
]
if
not
isinstance
(
control_image
,
list
):
control_image
=
[
control_image
]
if
not
isinstance
(
control_mode
,
list
):
control_mode
=
[
control_mode
]
if
len
(
control_image
)
!=
len
(
control_mode
):
raise
ValueError
(
"Expected len(control_image) == len(control_type)"
)
num_control_type
=
controlnet
.
config
.
num_control_type
# 1. Check inputs
control_type
=
[
0
for
_
in
range
(
num_control_type
)]
# 1. Check inputs. Raise error if not correct
control_type
=
[]
for
image_type
in
image
:
if
image
[
image_type
]:
self
.
check_inputs
(
prompt
,
prompt_2
,
image
[
image_type
],
negative_prompt
,
negative_prompt_2
,
prompt_embeds
,
negative_prompt_embeds
,
pooled_prompt_embeds
,
ip_adapter_image
,
ip_adapter_image_embeds
,
negative_pooled_prompt_embeds
,
controlnet_conditioning_scale
,
control_guidance_start
,
control_guidance_end
,
callback_on_step_end_tensor_inputs
,
)
control_type
.
append
(
1
)
else
:
control_type
.
append
(
0
)
for
_image
,
control_idx
in
zip
(
control_image
,
control_mode
):
control_type
[
control_idx
]
=
1
self
.
check_inputs
(
prompt
,
prompt_2
,
_image
,
negative_prompt
,
negative_prompt_2
,
prompt_embeds
,
negative_prompt_embeds
,
pooled_prompt_embeds
,
ip_adapter_image
,
ip_adapter_image_embeds
,
negative_pooled_prompt_embeds
,
controlnet_conditioning_scale
,
control_guidance_start
,
control_guidance_end
,
callback_on_step_end_tensor_inputs
,
)
control_type
=
torch
.
Tensor
(
control_type
)
...
...
@@ -1258,20 +1238,19 @@ class StableDiffusionXLControlNetUnionPipeline(
)
# 4. Prepare image
for
image_type
in
image
:
if
image
[
image_type
]:
image
[
image_type
]
=
self
.
prepare_image
(
image
=
image
[
image_type
],
width
=
width
,
height
=
height
,
batch_size
=
batch_size
*
num_images_per_prompt
,
num_images_per_prompt
=
num_images_per_prompt
,
device
=
device
,
dtype
=
controlnet
.
dtype
,
do_classifier_free_guidance
=
self
.
do_classifier_free_guidance
,
guess_mode
=
guess_mode
,
)
height
,
width
=
image
[
image_type
].
shape
[
-
2
:]
for
idx
,
_
in
enumerate
(
control_image
):
control_image
[
idx
]
=
self
.
prepare_image
(
image
=
control_image
[
idx
],
width
=
width
,
height
=
height
,
batch_size
=
batch_size
*
num_images_per_prompt
,
num_images_per_prompt
=
num_images_per_prompt
,
device
=
device
,
dtype
=
controlnet
.
dtype
,
do_classifier_free_guidance
=
self
.
do_classifier_free_guidance
,
guess_mode
=
guess_mode
,
)
height
,
width
=
control_image
[
idx
].
shape
[
-
2
:]
# 5. Prepare timesteps
timesteps
,
num_inference_steps
=
retrieve_timesteps
(
...
...
@@ -1312,11 +1291,11 @@ class StableDiffusionXLControlNetUnionPipeline(
)
# 7.2 Prepare added time ids & embeddings
for
image_type
in
image
:
if
isinstance
(
image
[
image_type
],
torch
.
Tensor
):
original_size
=
original_size
or
image
[
image_type
].
shape
[
-
2
:]
original_size
=
original_size
or
(
height
,
width
)
target_size
=
target_size
or
(
height
,
width
)
for
_image
in
control_image
:
if
isinstance
(
_image
,
torch
.
Tensor
):
original_size
=
original_size
or
_image
.
shape
[
-
2
:]
add_text_embeds
=
pooled_prompt_embeds
if
self
.
text_encoder_2
is
None
:
text_encoder_projection_dim
=
int
(
pooled_prompt_embeds
.
shape
[
-
1
])
...
...
@@ -1424,8 +1403,9 @@ class StableDiffusionXLControlNetUnionPipeline(
control_model_input
,
t
,
encoder_hidden_states
=
controlnet_prompt_embeds
,
controlnet_cond
=
image
,
controlnet_cond
=
control_
image
,
control_type
=
control_type
,
control_type_idx
=
control_mode
,
conditioning_scale
=
cond_scale
,
guess_mode
=
guess_mode
,
added_cond_kwargs
=
controlnet_added_cond_kwargs
,
...
...
@@ -1478,7 +1458,6 @@ class StableDiffusionXLControlNetUnionPipeline(
)
add_time_ids
=
callback_outputs
.
pop
(
"add_time_ids"
,
add_time_ids
)
negative_add_time_ids
=
callback_outputs
.
pop
(
"negative_add_time_ids"
,
negative_add_time_ids
)
image
=
callback_outputs
.
pop
(
"image"
,
image
)
# call the callback, if provided
if
i
==
len
(
timesteps
)
-
1
or
((
i
+
1
)
>
num_warmup_steps
and
(
i
+
1
)
%
self
.
scheduler
.
order
==
0
):
...
...
src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py

@@ -43,7 +43,6 @@ from ...models.attention_processor import (
     AttnProcessor2_0,
     XFormersAttnProcessor,
 )
-from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -74,7 +73,6 @@ EXAMPLE_DOC_STRING = """
             ControlNetUnionModel,
             AutoencoderKL,
         )
-        from diffusers.models.controlnets import ControlNetUnionInputProMax
         from diffusers.utils import load_image
         import torch
         from PIL import Image
@@ -95,6 +93,7 @@ EXAMPLE_DOC_STRING = """
             controlnet=controlnet,
             vae=vae,
             torch_dtype=torch.float16,
+            variant="fp16",
         ).to("cuda")
         # `enable_model_cpu_offload` is not recommended due to multiple generations
         height = image.height
@@ -132,14 +131,12 @@ EXAMPLE_DOC_STRING = """
         # set ControlNetUnion input
         result_images = []
         for sub_img, crops_coords in zip(images, crops_coords_list):
-            union_input = ControlNetUnionInputProMax(
-                tile=sub_img,
-            )
             new_width, new_height = W, H
             out = pipe(
                 prompt=[prompt] * 1,
                 image=sub_img,
-                control_image_list=union_input,
+                control_image=[sub_img],
+                control_mode=[6],
                 width=new_width,
                 height=new_height,
                 num_inference_steps=30,
@@ -1065,7 +1062,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
         image: PipelineImageInput = None,
-        control_image_list: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None,
+        control_image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         strength: float = 0.8,
@@ -1090,6 +1087,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
         guess_mode: bool = False,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
+        control_mode: Optional[Union[int, List[int]]] = None,
         original_size: Tuple[int, int] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
         target_size: Tuple[int, int] = None,
...
@@ -1119,10 +1117,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The initial image will be used as the starting point for the image generation process. Can also accept
image latents as `image`, if passing latents directly, it will not be encoded again.
control_image_list (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`):
In turn this supports (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`,
`List[PIL.Image.Image]`, `List[np.ndarray]`, `List[List[torch.FloatTensor]]`,
`List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`)::
control_image (`PipelineImageInput`):
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also
be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
...
...
@@ -1291,53 +1286,47 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
controlnet
=
self
.
controlnet
.
_orig_mod
if
is_compiled_module
(
self
.
controlnet
)
else
self
.
controlnet
if
not
isinstance
(
control_image_list
,
(
ControlNetUnionInput
,
ControlNetUnionInputProMax
)):
raise
ValueError
(
"Expected type of `control_image_list` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
)
if
len
(
control_image_list
)
!=
controlnet
.
config
.
num_control_type
:
if
isinstance
(
control_image_list
,
ControlNetUnionInput
):
raise
ValueError
(
f
"Expected num_control_type
{
controlnet
.
config
.
num_control_type
}
, got
{
len
(
control_image_list
)
}
. Try `ControlNetUnionInputProMax`."
)
elif
isinstance
(
control_image_list
,
ControlNetUnionInputProMax
):
raise
ValueError
(
f
"Expected num_control_type
{
controlnet
.
config
.
num_control_type
}
, got
{
len
(
control_image_list
)
}
. Try `ControlNetUnionInput`."
)
# align format for control guidance
if
not
isinstance
(
control_guidance_start
,
list
)
and
isinstance
(
control_guidance_end
,
list
):
control_guidance_start
=
len
(
control_guidance_end
)
*
[
control_guidance_start
]
elif
not
isinstance
(
control_guidance_end
,
list
)
and
isinstance
(
control_guidance_start
,
list
):
control_guidance_end
=
len
(
control_guidance_start
)
*
[
control_guidance_end
]
# 1. Check inputs. Raise error if not correct
control_type
=
[]
for
image_type
in
control_image_list
:
if
control_image_list
[
image_type
]:
self
.
check_inputs
(
prompt
,
prompt_2
,
control_image_list
[
image_type
],
strength
,
num_inference_steps
,
callback_steps
,
negative_prompt
,
negative_prompt_2
,
prompt_embeds
,
negative_prompt_embeds
,
pooled_prompt_embeds
,
negative_pooled_prompt_embeds
,
ip_adapter_image
,
ip_adapter_image_embeds
,
controlnet_conditioning_scale
,
control_guidance_start
,
control_guidance_end
,
callback_on_step_end_tensor_inputs
,
)
control_type
.
append
(
1
)
else
:
control_type
.
append
(
0
)
if
not
isinstance
(
control_image
,
list
):
control_image
=
[
control_image
]
if
not
isinstance
(
control_mode
,
list
):
control_mode
=
[
control_mode
]
if
len
(
control_image
)
!=
len
(
control_mode
):
raise
ValueError
(
"Expected len(control_image) == len(control_type)"
)
num_control_type
=
controlnet
.
config
.
num_control_type
# 1. Check inputs
control_type
=
[
0
for
_
in
range
(
num_control_type
)]
for
_image
,
control_idx
in
zip
(
control_image
,
control_mode
):
control_type
[
control_idx
]
=
1
self
.
check_inputs
(
prompt
,
prompt_2
,
_image
,
strength
,
num_inference_steps
,
callback_steps
,
negative_prompt
,
negative_prompt_2
,
prompt_embeds
,
negative_prompt_embeds
,
pooled_prompt_embeds
,
negative_pooled_prompt_embeds
,
ip_adapter_image
,
ip_adapter_image_embeds
,
controlnet_conditioning_scale
,
control_guidance_start
,
control_guidance_end
,
callback_on_step_end_tensor_inputs
,
)
control_type
=
torch
.
Tensor
(
control_type
)
...
...
@@ -1397,21 +1386,19 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
         # 4. Prepare image and controlnet_conditioning_image
         image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)

-        for image_type in control_image_list:
-            if control_image_list[image_type]:
-                control_image = self.prepare_control_image(
-                    image=control_image_list[image_type],
-                    width=width,
-                    height=height,
-                    batch_size=batch_size * num_images_per_prompt,
-                    num_images_per_prompt=num_images_per_prompt,
-                    device=device,
-                    dtype=controlnet.dtype,
-                    do_classifier_free_guidance=self.do_classifier_free_guidance,
-                    guess_mode=guess_mode,
-                )
-                height, width = control_image.shape[-2:]
-                control_image_list[image_type] = control_image
+        for idx, _ in enumerate(control_image):
+            control_image[idx] = self.prepare_control_image(
+                image=control_image[idx],
+                width=width,
+                height=height,
+                batch_size=batch_size * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=device,
+                dtype=controlnet.dtype,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
+                guess_mode=guess_mode,
+            )
+            height, width = control_image[idx].shape[-2:]

         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -1444,10 +1431,11 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
         )

         # 7.2 Prepare added time ids & embeddings
-        for image_type in control_image_list:
-            if isinstance(control_image_list[image_type], torch.Tensor):
-                original_size = original_size or control_image_list[image_type].shape[-2:]
         original_size = original_size or (height, width)
         target_size = target_size or (height, width)

+        for _image in control_image:
+            if isinstance(_image, torch.Tensor):
+                original_size = original_size or _image.shape[-2:]
+
         if negative_original_size is None:
             negative_original_size = original_size
@@ -1531,8 +1519,9 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
                         control_model_input,
                         t,
                         encoder_hidden_states=controlnet_prompt_embeds,
-                        controlnet_cond=control_image_list,
+                        controlnet_cond=control_image,
                         control_type=control_type,
+                        control_type_idx=control_mode,
                         conditioning_scale=cond_scale,
                         guess_mode=guess_mode,
                         added_cond_kwargs=controlnet_added_cond_kwargs,