Unverified Commit 58fc8244 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

add: better warning messages when handling multiple conditionings. (#2804)

* add: better warning messages when handling multiple conditioning.

* fix: handling of controlnet_conditioning_scale
parent fab4f3d6
......@@ -537,15 +537,27 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
f" {negative_prompt_embeds.shape}."
)
# Check `image`
# `prompt` needs more sophisticated handling when there are multiple
# conditionings.
if isinstance(self.controlnet, MultiControlNetModel):
if isinstance(prompt, list):
logger.warning(
f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
" prompts. The conditionings will be fixed across the prompts."
)
# Check `image`
if isinstance(self.controlnet, ControlNetModel):
self.check_image(image, prompt, prompt_embeds)
elif isinstance(self.controlnet, MultiControlNetModel):
if not isinstance(image, list):
raise TypeError("For multiple controlnets: `image` must be type `list`")
if len(image) != len(self.controlnet.nets):
# When `image` is a nested list:
# (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
elif any(isinstance(i, list) for i in image):
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
elif len(image) != len(self.controlnet.nets):
raise ValueError(
"For multiple controlnets: `image` must have the same length as the number of controlnets."
)
......@@ -556,12 +568,14 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
assert False
# Check `controlnet_conditioning_scale`
if isinstance(self.controlnet, ControlNetModel):
if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
elif isinstance(self.controlnet, MultiControlNetModel):
if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
if isinstance(controlnet_conditioning_scale, list):
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
self.controlnet.nets
):
raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment