Unverified Commit 613e77f8 authored by CyberVy's avatar CyberVy Committed by GitHub
Browse files

Fix Callback Tensor Inputs of the SDXL Controlnet Inpaint and Img2img...

Fix Callback Tensor Inputs of the SDXL Controlnet Inpaint and Img2img Pipelines are missing "controlnet_image". (#10880)

* Update pipeline_controlnet_inpaint_sd_xl.py

* Update pipeline_controlnet_sd_xl_img2img.py

* Update pipeline_controlnet_union_inpaint_sd_xl.py

* Update pipeline_controlnet_union_sd_xl_img2img.py

* Update pipeline_controlnet_inpaint_sd_xl.py

* Update pipeline_controlnet_sd_xl_img2img.py

* Update pipeline_controlnet_union_inpaint_sd_xl.py

* Update pipeline_controlnet_union_sd_xl_img2img.py

* Apply make style and make fix-copies fixes

* Update geodiff_molecule_conformation.ipynb

* Delete examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb

* Delete examples/research_projects/gligen/demo.ipynb

* Create geodiff_molecule_conformation.ipynb

* Create demo.ipynb

* Update geodiff_molecule_conformation.ipynb

* Update geodiff_molecule_conformation.ipynb

* Delete examples/research_projects/geodiff/geodiff_molecule_conformation.ipynb

* Add files via upload

* Delete src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

* Add files via upload
parent 1450c2ac
...@@ -237,6 +237,7 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -237,6 +237,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
"add_neg_time_ids", "add_neg_time_ids",
"mask", "mask",
"masked_image_latents", "masked_image_latents",
"control_image",
] ]
def __init__( def __init__(
...@@ -743,7 +744,7 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -743,7 +744,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
if padding_mask_crop is not None: if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image): if not isinstance(image, PIL.Image.Image):
raise ValueError( raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
) )
if not isinstance(mask_image, PIL.Image.Image): if not isinstance(mask_image, PIL.Image.Image):
raise ValueError( raise ValueError(
...@@ -751,7 +752,7 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -751,7 +752,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
f" {type(mask_image)}." f" {type(mask_image)}."
) )
if output_type != "pil": if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if prompt_embeds is not None and pooled_prompt_embeds is None: if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError( raise ValueError(
...@@ -1644,7 +1645,7 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -1644,7 +1645,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input." " `pipeline.unet` or your `mask_image` or `image` input."
) )
elif num_channels_unet != 4: elif num_channels_unet != 4:
...@@ -1835,6 +1836,7 @@ class StableDiffusionXLControlNetInpaintPipeline( ...@@ -1835,6 +1836,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
latents = callback_outputs.pop("latents", latents) latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
control_image = callback_outputs.pop("control_image", control_image)
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
......
...@@ -242,6 +242,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -242,6 +242,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
"add_time_ids", "add_time_ids",
"negative_pooled_prompt_embeds", "negative_pooled_prompt_embeds",
"add_neg_time_ids", "add_neg_time_ids",
"control_image",
] ]
def __init__( def __init__(
...@@ -1614,6 +1615,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( ...@@ -1614,6 +1615,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
) )
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
control_image = callback_outputs.pop("control_image", control_image)
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
......
...@@ -219,6 +219,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline( ...@@ -219,6 +219,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
"add_time_ids", "add_time_ids",
"mask", "mask",
"masked_image_latents", "masked_image_latents",
"control_image",
] ]
def __init__( def __init__(
...@@ -726,7 +727,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline( ...@@ -726,7 +727,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
if padding_mask_crop is not None: if padding_mask_crop is not None:
if not isinstance(image, PIL.Image.Image): if not isinstance(image, PIL.Image.Image):
raise ValueError( raise ValueError(
f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}." f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
) )
if not isinstance(mask_image, PIL.Image.Image): if not isinstance(mask_image, PIL.Image.Image):
raise ValueError( raise ValueError(
...@@ -734,7 +735,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline( ...@@ -734,7 +735,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
f" {type(mask_image)}." f" {type(mask_image)}."
) )
if output_type != "pil": if output_type != "pil":
raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.") raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
if prompt_embeds is not None and pooled_prompt_embeds is None: if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError( raise ValueError(
...@@ -1743,6 +1744,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline( ...@@ -1743,6 +1744,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
latents = callback_outputs.pop("latents", latents) latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
control_image = callback_outputs.pop("control_image", control_image)
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
......
...@@ -252,12 +252,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( ...@@ -252,12 +252,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
"feature_extractor", "feature_extractor",
"image_encoder", "image_encoder",
] ]
_callback_tensor_inputs = [ _callback_tensor_inputs = ["latents", "prompt_embeds", "add_text_embeds", "add_time_ids", "control_image"]
"latents",
"prompt_embeds",
"add_text_embeds",
"add_time_ids",
]
def __init__( def __init__(
self, self,
...@@ -1562,6 +1557,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline( ...@@ -1562,6 +1557,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
control_image = callback_outputs.pop("control_image", control_image)
# call the callback, if provided # call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment