Unverified Commit 58fc8244 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

add: better warning messages when handling multiple conditionings. (#2804)

* add: better warning messages when handling multiple conditioning.

* fix: handling of controlnet_conditioning_scale
parent fab4f3d6
...@@ -537,15 +537,27 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline): ...@@ -537,15 +537,27 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
f" {negative_prompt_embeds.shape}." f" {negative_prompt_embeds.shape}."
) )
# Check `image` # `prompt` needs more sophisticated handling when there are multiple
# conditionings.
if isinstance(self.controlnet, MultiControlNetModel):
if isinstance(prompt, list):
logger.warning(
f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
" prompts. The conditionings will be fixed across the prompts."
)
# Check `image`
if isinstance(self.controlnet, ControlNetModel): if isinstance(self.controlnet, ControlNetModel):
self.check_image(image, prompt, prompt_embeds) self.check_image(image, prompt, prompt_embeds)
elif isinstance(self.controlnet, MultiControlNetModel): elif isinstance(self.controlnet, MultiControlNetModel):
if not isinstance(image, list): if not isinstance(image, list):
raise TypeError("For multiple controlnets: `image` must be type `list`") raise TypeError("For multiple controlnets: `image` must be type `list`")
if len(image) != len(self.controlnet.nets): # When `image` is a nested list:
# (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
elif any(isinstance(i, list) for i in image):
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
elif len(image) != len(self.controlnet.nets):
raise ValueError( raise ValueError(
"For multiple controlnets: `image` must have the same length as the number of controlnets." "For multiple controlnets: `image` must have the same length as the number of controlnets."
) )
...@@ -556,12 +568,14 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline): ...@@ -556,12 +568,14 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
assert False assert False
# Check `controlnet_conditioning_scale` # Check `controlnet_conditioning_scale`
if isinstance(self.controlnet, ControlNetModel): if isinstance(self.controlnet, ControlNetModel):
if not isinstance(controlnet_conditioning_scale, float): if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
elif isinstance(self.controlnet, MultiControlNetModel): elif isinstance(self.controlnet, MultiControlNetModel):
if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( if isinstance(controlnet_conditioning_scale, list):
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
self.controlnet.nets self.controlnet.nets
): ):
raise ValueError( raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment